Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

0
src/core/__init__.py Normal file
View File

View File

@@ -0,0 +1,272 @@
"""
集中维护 API 格式的元数据,避免新增格式时到处修改常量。
此模块与 src/formats/ 的 FormatProtocol 系统配合使用:
- api_format_metadata: 定义格式的元数据(别名、默认路径)
- src/formats/: 定义格式的协议实现(解析、转换、验证)
使用方式:
# 解析格式别名
from src.core.api_format_metadata import resolve_api_format
api_format = resolve_api_format("claude") # -> APIFormat.CLAUDE
# 获取格式协议
from src.core.api_format_metadata import get_format_protocol
protocol = get_format_protocol(APIFormat.CLAUDE) # -> ClaudeProtocol
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from functools import lru_cache
from types import MappingProxyType
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
from .enums import APIFormat
@dataclass(frozen=True)
class ApiFormatDefinition:
    """
    Describes all shared metadata for a single API format.

    Attributes:
        api_format: The APIFormat enum member this definition describes.
        aliases: Provider aliases / shortcut names used by detect_api_format.
        default_path: Default upstream request path (e.g. /v1/messages);
            may be overridden via Endpoint.custom_path.
        path_prefix: Local site path prefix (e.g. /claude, /openai);
            empty string means no prefix.
        auth_header: Name of the auth header (e.g. "x-api-key", "x-goog-api-key").
        auth_type: Auth style — "header" puts the raw value in the header,
            "bearer" prepends a "Bearer " prefix.
    """

    api_format: APIFormat
    aliases: Sequence[str] = field(default_factory=tuple)
    default_path: str = "/"  # default upstream request path
    path_prefix: str = ""  # local path prefix; empty means none
    auth_header: str = "Authorization"
    auth_type: str = "bearer"  # "bearer" or "header"

    def iter_aliases(self) -> Iterable[str]:
        """Yield the normalized alias set, starting with the enum value itself."""
        yield normalize_alias_value(self.api_format.value)
        for raw_alias in self.aliases:
            candidate = normalize_alias_value(raw_alias)
            if candidate:
                yield candidate
# Registry of built-in format definitions. Mutated only through
# register_api_format_definition(); all other code should read the
# read-only API_FORMAT_DEFINITIONS view declared below.
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
    APIFormat.CLAUDE: ApiFormatDefinition(
        api_format=APIFormat.CLAUDE,
        aliases=("claude", "anthropic", "claude_compatible"),
        default_path="/v1/messages",
        path_prefix="",  # local path prefix, configurable e.g. "/claude"
        auth_header="x-api-key",
        auth_type="header",
    ),
    APIFormat.CLAUDE_CLI: ApiFormatDefinition(
        api_format=APIFormat.CLAUDE_CLI,
        aliases=("claude_cli", "claude-cli"),
        default_path="/v1/messages",
        path_prefix="",  # shares the CLAUDE entry point; distinguished by header
        auth_header="authorization",
        auth_type="bearer",
    ),
    APIFormat.OPENAI: ApiFormatDefinition(
        api_format=APIFormat.OPENAI,
        aliases=(
            "openai",
            "deepseek",
            "grok",
            "moonshot",
            "zhipu",
            "qwen",
            "baichuan",
            "minimax",
            "openai_compatible",
        ),
        default_path="/v1/chat/completions",
        path_prefix="",  # local path prefix, configurable e.g. "/openai"
        auth_header="Authorization",
        auth_type="bearer",
    ),
    APIFormat.OPENAI_CLI: ApiFormatDefinition(
        api_format=APIFormat.OPENAI_CLI,
        aliases=("openai_cli", "responses"),
        default_path="/responses",
        path_prefix="",
        auth_header="Authorization",
        auth_type="bearer",
    ),
    APIFormat.GEMINI: ApiFormatDefinition(
        api_format=APIFormat.GEMINI,
        aliases=("gemini", "google", "vertex"),
        default_path="/v1beta/models/{model}:{action}",
        path_prefix="",  # local path prefix, configurable e.g. "/gemini"
        auth_header="x-goog-api-key",
        auth_type="header",
    ),
    APIFormat.GEMINI_CLI: ApiFormatDefinition(
        api_format=APIFormat.GEMINI_CLI,
        aliases=("gemini_cli", "gemini-cli"),
        default_path="/v1beta/models/{model}:{action}",
        path_prefix="",  # shares the GEMINI entry point
        auth_header="x-goog-api-key",
        auth_type="header",
    ),
}
# Expose only a read-only view so callers cannot mutate the registry directly.
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
    """Look up the definition for *api_format*; raises KeyError when absent."""
    definition = API_FORMAT_DEFINITIONS[api_format]
    return definition
def list_api_format_definitions() -> List[ApiFormatDefinition]:
    """Return a fresh list of every registered definition, for iteration."""
    return [definition for definition in API_FORMAT_DEFINITIONS.values()]
def build_alias_lookup() -> Dict[str, APIFormat]:
    """
    Build an alias -> APIFormat lookup table.

    Returns a brand-new dict on every call so no mutable global is shared
    across concurrent users. The first definition to claim an alias wins.
    """
    table: Dict[str, APIFormat] = {}
    for definition in API_FORMAT_DEFINITIONS.values():
        for alias in definition.iter_aliases():
            if alias not in table:
                table[alias] = definition.api_format
    return table
def get_default_path(api_format: APIFormat) -> str:
    """
    Return the default upstream request path for the format.

    Falls back to "/" for unknown formats; callers may still override the
    result via Endpoint.custom_path.
    """
    try:
        return API_FORMAT_DEFINITIONS[api_format].default_path
    except KeyError:
        return "/"
def get_local_path(api_format: APIFormat) -> str:
    """
    Return the local entry path for the format.

    Local entry path = path_prefix + default_path, e.g.
    path_prefix="/openai" + default_path="/v1/chat/completions"
    -> "/openai/v1/chat/completions". Unknown formats yield "/".
    """
    definition = API_FORMAT_DEFINITIONS.get(api_format)
    if definition is None:
        return "/"
    return (definition.path_prefix or "") + definition.default_path
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
    """
    Return the auth configuration for the format.

    Returns:
        (auth_header, auth_type) tuple where auth_header is the header name
        and auth_type is "bearer" or "header". Unknown formats default to
        ("Authorization", "bearer").
    """
    definition = API_FORMAT_DEFINITIONS.get(api_format)
    if definition is None:
        return "Authorization", "bearer"
    return definition.auth_header, definition.auth_type
@lru_cache(maxsize=1)
def _alias_lookup_cache() -> Dict[str, APIFormat]:
    """Cache the alias -> APIFormat table to avoid rebuilding it per lookup."""
    return build_alias_lookup()
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
    """Resolve an alias string to an APIFormat; returns None when unknown."""
    if not value:
        return None
    key = normalize_alias_value(value)
    return _alias_lookup_cache().get(key) if key else None
def resolve_api_format(
    value: Union[str, APIFormat, None],
    default: Optional[APIFormat] = None,
) -> Optional[APIFormat]:
    """
    Resolve an arbitrary string/enum value into an APIFormat.

    Args:
        value: An APIFormat member, or any string / alias.
        default: Value returned when resolution fails.
    """
    if isinstance(value, APIFormat):
        return value
    if not isinstance(value, str):
        return default

    stripped = value.strip()
    if not stripped:
        return default

    # Exact enum-name match first (case-insensitive), then alias lookup.
    member = APIFormat.__members__.get(stripped.upper())
    if member is not None:
        return member
    alias = resolve_api_format_alias(stripped)
    return alias if alias else default
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
    """
    Register or replace an API format definition at runtime.

    Args:
        definition: The definition to register.
        override: Whether replacing an existing entry is allowed.

    Raises:
        ValueError: When the format already exists and override is False.
    """
    if not override and definition.api_format in _DEFINITIONS:
        raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")
    _DEFINITIONS[definition.api_format] = definition
    _refresh_metadata_cache()
def _refresh_metadata_cache():
    """Invalidate the alias cache; called by register_api_format_definition()."""
    _alias_lookup_cache.cache_clear()
def normalize_alias_value(value: Optional[str]) -> str:
    """
    Normalize an alias: strip whitespace, lowercase, and collapse runs of
    non-alphanumeric characters into single underscores.

    Args:
        value: Raw alias text; None is tolerated and yields "".

    Returns:
        The normalized alias, possibly "" when nothing alphanumeric remains.
    """
    # The annotation is Optional[str]: the body has always accepted None,
    # the previous `value: str` hint was simply wrong.
    if value is None:
        return ""
    text = value.strip().lower()
    # Replace every non-alphanumeric run with one underscore, then trim
    # leading/trailing underscores left by punctuation at the edges.
    text = re.sub(r"[^a-z0-9]+", "_", text)
    return text.strip("_")
# =============================================================================
# Format helpers
# =============================================================================
def is_cli_api_format(api_format: APIFormat) -> bool:
    """
    Return True when the format is a CLI pass-through format.

    Args:
        api_format: APIFormat enum member.

    Returns:
        True if the format is a CLI format.
    """
    # Imported lazily to avoid a circular import with src.api.handlers.
    from src.api.handlers.base.parsers import is_cli_format

    return is_cli_format(api_format.value)

115
src/core/batch_committer.py Normal file
View File

@@ -0,0 +1,115 @@
"""
批量提交器 - 减少数据库 commit 次数,提升并发能力
核心思想:
- 非关键数据(监控、统计)不立即 commit
- 在后台定期批量 commit
- 关键数据(计费)仍然立即 commit
"""
import asyncio
from typing import Optional, Set

from sqlalchemy.orm import Session

from src.core.logger import logger
class BatchCommitter:
    """
    Batch commit manager.

    Collects dirty SQLAlchemy sessions and commits them periodically from a
    background asyncio task, cutting per-request commit overhead for
    non-critical data (monitoring, statistics). Critical data (billing)
    should still be committed immediately by the caller.
    """

    def __init__(self, interval_seconds: float = 1.0):
        """
        Args:
            interval_seconds: Interval between batch commits, in seconds.
        """
        self.interval_seconds = interval_seconds
        self._pending_sessions: Set[Session] = set()
        self._lock = asyncio.Lock()
        # Background task handle; None while the committer is stopped.
        self._task = None

    async def start(self):
        """Start the background batch-commit task (idempotent)."""
        if self._task is None:
            self._task = asyncio.create_task(self._batch_commit_loop())
            logger.info(f"批量提交器已启动,间隔: {self.interval_seconds}s")

    async def stop(self):
        """Cancel the background task; pending sessions are flushed first."""
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
            logger.info("批量提交器已停止")

    def mark_dirty(self, session: Session):
        """
        Register a Session that has uncommitted changes.

        Called without the asyncio lock: the event loop is single-threaded,
        so set.add cannot interleave with _commit_all's snapshot.
        """
        self._pending_sessions.add(session)

    async def _batch_commit_loop(self):
        """Background loop: sleep, then commit everything marked dirty."""
        while True:
            try:
                await asyncio.sleep(self.interval_seconds)
                await self._commit_all()
            except asyncio.CancelledError:
                # Flush outstanding sessions before shutting down.
                await self._commit_all()
                raise
            except Exception as e:
                logger.error(f"批量提交出错: {e}")

    async def _commit_all(self):
        """Commit all pending Sessions; failures are rolled back and dropped."""
        async with self._lock:
            if not self._pending_sessions:
                return
            sessions_to_commit = list(self._pending_sessions)
            self._pending_sessions.clear()
            committed = 0
            failed = 0
            for session in sessions_to_commit:
                try:
                    session.commit()
                    committed += 1
                except Exception as e:
                    logger.error(f"提交 Session 失败: {e}")
                    try:
                        session.rollback()
                    # Was a bare `except:` — never swallow SystemExit /
                    # KeyboardInterrupt; best-effort rollback only.
                    except Exception:
                        pass
                    failed += 1
            if committed > 0:
                logger.debug(f"批量提交完成: {committed} 个 Session")
            if failed > 0:
                logger.warning(f"批量提交失败: {failed} 个 Session")
# Global singleton; the annotation must admit None (the previous
# `BatchCommitter = None` annotation was simply wrong).
_batch_committer: Optional["BatchCommitter"] = None


def get_batch_committer() -> "BatchCommitter":
    """Return the process-wide BatchCommitter, creating it lazily."""
    global _batch_committer
    if _batch_committer is None:
        _batch_committer = BatchCommitter(interval_seconds=1.0)
    return _batch_committer
async def init_batch_committer():
    """Initialize and start the global batch committer (app startup hook)."""
    committer = get_batch_committer()
    await committer.start()
async def shutdown_batch_committer():
    """Stop the global batch committer, flushing pending work (shutdown hook)."""
    committer = get_batch_committer()
    await committer.stop()

174
src/core/cache_service.py Normal file
View File

@@ -0,0 +1,174 @@
"""
缓存服务 - 统一的缓存抽象层
"""
import json
from datetime import timedelta
from typing import Any, Optional
from src.clients.redis_client import get_redis_client
from src.core.logger import logger
class CacheService:
    """
    Unified cache abstraction over the (optional) Redis client.

    All methods degrade gracefully: when Redis is unavailable or errors out,
    reads behave as cache misses and writes report failure instead of raising.
    """

    @staticmethod
    async def get(key: str) -> Optional[Any]:
        """
        Fetch a value from the cache.

        Args:
            key: Cache key.

        Returns:
            The cached value (JSON-decoded when possible), or None on a miss.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return None
            value = await redis.get(key)
            # Compare against None explicitly: a cached falsy value such as
            # "" must still count as a hit (the old `if value:` dropped it).
            if value is not None:
                try:
                    return json.loads(value)
                except (json.JSONDecodeError, TypeError):
                    return value
            return None
        except Exception as e:
            logger.warning(f"缓存读取失败: {key} - {e}")
            return None

    @staticmethod
    async def set(key: str, value: Any, ttl_seconds: int = 60) -> bool:
        """
        Store a value in the cache.

        Args:
            key: Cache key.
            value: Value to cache; dicts/lists are JSON-encoded, other
                non-string values are stringified.
            ttl_seconds: Expiry in seconds (default 60).

        Returns:
            True when the value was stored.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False
            # JSON-serialize containers; coerce everything else to str.
            if isinstance(value, (dict, list)):
                value = json.dumps(value)
            elif not isinstance(value, (str, bytes)):
                value = str(value)
            await redis.setex(key, ttl_seconds, value)
            return True
        except Exception as e:
            logger.warning(f"缓存写入失败: {key} - {e}")
            return False

    @staticmethod
    async def delete(key: str) -> bool:
        """
        Delete a cache entry.

        Args:
            key: Cache key.

        Returns:
            True when the delete command was issued successfully.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False
            await redis.delete(key)
            return True
        except Exception as e:
            logger.warning(f"缓存删除失败: {key} - {e}")
            return False

    @staticmethod
    async def exists(key: str) -> bool:
        """
        Check whether a cache entry exists.

        Args:
            key: Cache key.

        Returns:
            True when the key exists.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False
            return await redis.exists(key) > 0
        except Exception as e:
            logger.warning(f"缓存检查失败: {key} - {e}")
            return False
# Cache key prefixes
class CacheKeys:
    """Cache key templates and the helpers that render them."""

    # User cache (TTL 60s)
    USER_BY_ID = "user:id:{user_id}"
    USER_BY_EMAIL = "user:email:{email}"
    # API key cache (TTL 30s)
    APIKEY_HASH = "apikey:hash:{key_hash}"
    APIKEY_AUTH = "apikey:auth:{key_hash}"  # cached auth result
    # Provider configuration cache (TTL 300s)
    PROVIDER_BY_ID = "provider:id:{provider_id}"
    ENDPOINT_BY_ID = "endpoint:id:{endpoint_id}"
    API_KEY_BY_ID = "api_key:id:{api_key_id}"

    @staticmethod
    def user_by_id(user_id: str) -> str:
        """Cache key for a user id."""
        template = CacheKeys.USER_BY_ID
        return template.format(user_id=user_id)

    @staticmethod
    def user_by_email(email: str) -> str:
        """Cache key for a user email."""
        template = CacheKeys.USER_BY_EMAIL
        return template.format(email=email)

    @staticmethod
    def apikey_hash(key_hash: str) -> str:
        """Cache key for an API key hash."""
        template = CacheKeys.APIKEY_HASH
        return template.format(key_hash=key_hash)

    @staticmethod
    def apikey_auth(key_hash: str) -> str:
        """Cache key for a cached API key auth result."""
        template = CacheKeys.APIKEY_AUTH
        return template.format(key_hash=key_hash)

    @staticmethod
    def provider_by_id(provider_id: str) -> str:
        """Cache key for a provider id."""
        template = CacheKeys.PROVIDER_BY_ID
        return template.format(provider_id=provider_id)

    @staticmethod
    def endpoint_by_id(endpoint_id: str) -> str:
        """Cache key for an endpoint id."""
        template = CacheKeys.ENDPOINT_BY_ID
        return template.format(endpoint_id=endpoint_id)

    @staticmethod
    def api_key_by_id(api_key_id: str) -> str:
        """Cache key for an API key id."""
        template = CacheKeys.API_KEY_BY_ID
        return template.format(api_key_id=api_key_id)

133
src/core/cache_utils.py Normal file
View File

@@ -0,0 +1,133 @@
"""
缓存工具类
提供同步缓存接口,用于不适合使用异步缓存的场景
"""
import threading
import time
from collections import OrderedDict
from typing import Any, Dict, Optional
class SyncLRUCache:
    """
    Thread-safe synchronous LRU cache with per-entry TTL.

    Intended for call sites that cannot use an async cache, such as
    ModelMapperMiddleware.
    """

    def __init__(self, max_size: int = 1000, ttl: int = 300) -> None:
        """
        Args:
            max_size: Maximum number of cached entries.
            ttl: Default time-to-live in seconds.
        """
        self._cache: OrderedDict = OrderedDict()
        self._expiry: Dict[Any, float] = {}
        self.max_size = max_size
        self.ttl = ttl
        self._lock = threading.RLock()

    def _is_expired(self, key: Any) -> bool:
        """Whether *key* has passed its deadline (caller must hold the lock)."""
        deadline = self._expiry.get(key)
        return deadline is not None and time.time() > deadline

    def _delete_key(self, key: Any) -> None:
        """Remove *key* from both maps (caller must hold the lock)."""
        self._cache.pop(key, None)
        self._expiry.pop(key, None)

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the cached value, or *default* on a miss/expiry."""
        try:
            return self[key]
        except KeyError:
            return default

    def set(self, key: Any, value: Any, ttl: Optional[int] = None) -> None:
        """Store *value* under *key*, evicting LRU entries past max_size."""
        effective_ttl = self.ttl if ttl is None else ttl
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            self._cache[key] = value
            self._expiry[key] = time.time() + effective_ttl
            while len(self._cache) > self.max_size:
                self._delete_key(next(iter(self._cache)))

    def delete(self, key: Any) -> None:
        """Remove *key* if present."""
        with self._lock:
            self._delete_key(key)

    def clear(self) -> None:
        """Drop every entry."""
        with self._lock:
            self._cache.clear()
            self._expiry.clear()

    def __contains__(self, key: Any) -> bool:
        """Membership test; expired entries are purged and report False."""
        with self._lock:
            present = key in self._cache
            if present and self._is_expired(key):
                self._delete_key(key)
                present = False
            return present

    def __getitem__(self, key: Any) -> Any:
        """Indexed access; raises KeyError for missing or expired keys."""
        with self._lock:
            if key not in self._cache:
                raise KeyError(key)
            if self._is_expired(key):
                self._delete_key(key)
                raise KeyError(key)
            self._cache.move_to_end(key)
            return self._cache[key]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Indexed assignment; delegates to set() with the default TTL."""
        self.set(key, value)

    def __delitem__(self, key: Any) -> None:
        """Indexed deletion; delegates to delete()."""
        self.delete(key)

    def keys(self):
        """Return the keys that have not expired (expired ones are not purged)."""
        with self._lock:
            now = time.time()
            live = []
            for key in self._cache:
                deadline = self._expiry.get(key)
                if deadline is None or now <= deadline:
                    live.append(key)
            return live

    def get_stats(self) -> Dict[str, Any]:
        """Return size/limit/TTL statistics for monitoring."""
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "ttl": self.ttl,
            }

168
src/core/context.py Normal file
View File

@@ -0,0 +1,168 @@
"""
统一的请求上下文
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
这确保了数据在各层之间传递时不会丢失。
使用方式:
1. Pipeline 层创建 RequestContext
2. 各层通过 context 访问和更新信息
3. Adapter 层使用 context 记录 Usage
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class RequestContext:
    """
    Request context that travels through the entire request lifecycle.

    Design principles:
    1. Created at the start of a request with everything known at that point.
    2. Progressively enriched with Provider information during execution.
    3. Used at the end of the request to record Usage.
    """

    # ==================== Request identity ====================
    request_id: str
    # ==================== Authentication ====================
    user: Any  # User model
    api_key: Any  # ApiKey model
    db: Any  # Database session
    # ==================== Request info ====================
    api_format: str  # CLAUDE, OPENAI, GEMINI, etc.
    model: str  # model name as requested by the user
    is_stream: bool = False
    # ==================== Original request ====================
    original_headers: Dict[str, str] = field(default_factory=dict)
    original_body: Dict[str, Any] = field(default_factory=dict)
    # ==================== Client info ====================
    client_ip: str = "unknown"
    user_agent: str = ""
    # ==================== Timing ====================
    start_time: float = field(default_factory=time.time)
    # ==================== Provider info (filled in after execution) ====================
    provider_name: Optional[str] = None
    provider_id: Optional[str] = None
    endpoint_id: Optional[str] = None
    provider_api_key_id: Optional[str] = None
    # ==================== Model mapping ====================
    resolved_model: Optional[str] = None  # model name after mapping
    original_model: Optional[str] = None  # original model name (used for pricing)
    # ==================== Request/response headers ====================
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_response_headers: Dict[str, str] = field(default_factory=dict)
    # ==================== Tracing ====================
    attempt_id: Optional[str] = None
    # ==================== Capability requirements ====================
    capability_requirements: Dict[str, bool] = field(default_factory=dict)
    # Capability requirements computed at runtime, sourced from:
    # 1. the user's model_capability_settings
    # 2. the user's ApiKey.force_capabilities
    # 3. the X-Require-Capability request header
    # 4. dynamically added on failure retries

    @classmethod
    def create(
        cls,
        *,
        db: Any,
        user: Any,
        api_key: Any,
        api_format: str,
        model: str,
        is_stream: bool = False,
        original_headers: Optional[Dict[str, str]] = None,
        original_body: Optional[Dict[str, Any]] = None,
        client_ip: str = "unknown",
        user_agent: str = "",
        request_id: Optional[str] = None,
    ) -> "RequestContext":
        """Create a request context; generates a request_id when none is given."""
        return cls(
            request_id=request_id or str(uuid.uuid4()),
            db=db,
            user=user,
            api_key=api_key,
            api_format=api_format,
            model=model,
            is_stream=is_stream,
            original_headers=original_headers or {},
            original_body=original_body or {},
            client_ip=client_ip,
            user_agent=user_agent,
            original_model=model,  # initially the original model equals the requested model
        )

    def update_provider_info(
        self,
        *,
        provider_name: str,
        provider_id: str,
        endpoint_id: str,
        provider_api_key_id: str,
        resolved_model: Optional[str] = None,
    ) -> None:
        """Record Provider information (called after the request executes)."""
        self.provider_name = provider_name
        self.provider_id = provider_id
        self.endpoint_id = endpoint_id
        self.provider_api_key_id = provider_api_key_id
        if resolved_model:
            self.resolved_model = resolved_model

    def update_headers(
        self,
        *,
        request_headers: Optional[Dict[str, str]] = None,
        response_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Record the provider request/response headers (None leaves a side unchanged)."""
        if request_headers:
            self.provider_request_headers = request_headers
        if response_headers:
            self.provider_response_headers = response_headers

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since start_time."""
        return int((time.time() - self.start_time) * 1000)

    @property
    def effective_model(self) -> str:
        """Effective model name (mapped name takes precedence)."""
        return self.resolved_model or self.model

    @property
    def billing_model(self) -> str:
        """Billing model name (original name takes precedence)."""
        return self.original_model or self.model

    def to_metadata_dict(self) -> Dict[str, Any]:
        """Convert to a metadata dict for Usage records."""
        return {
            "api_format": self.api_format,
            "provider": self.provider_name or "unknown",
            "model": self.effective_model,
            "original_model": self.billing_model,
            "provider_id": self.provider_id,
            "provider_endpoint_id": self.endpoint_id,
            "provider_api_key_id": self.provider_api_key_id,
            "provider_request_headers": self.provider_request_headers,
            "provider_response_headers": self.provider_response_headers,
            "attempt_id": self.attempt_id,
        }

166
src/core/crypto.py Normal file
View File

@@ -0,0 +1,166 @@
"""
加密工具模块
提供API密钥的加密和解密功能
安全说明:
- 生产环境必须设置独立的 ENCRYPTION_KEY
- 加密密钥应独立于 JWT_SECRET_KEY避免密钥轮换问题
- 使用 PBKDF2 派生密钥时会使用应用级 salt
"""
import base64
import hashlib
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from ..config import config
from ..core.exceptions import DecryptionException
from src.core.logger import logger
class CryptoService:
    """
    Encryption service (process-wide singleton).

    Provides symmetric encryption to protect sensitive data such as Provider
    API keys, using Fernet (AES-128-CBC + HMAC-SHA256) for confidentiality
    and integrity.
    """

    _instance = None
    _cipher = None
    _key_source: str = "unknown"  # records where the key came from, for debugging
    # Application-level salt derived from the app name (safer than a bare
    # hard-coded salt). NOTE: changing this value makes all previously
    # encrypted data undecryptable.
    APP_SALT = hashlib.sha256(b"aether-v1").digest()[:16]

    def __new__(cls) -> "CryptoService":
        # Classic singleton: initialize once on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self) -> None:
        """Initialize the cipher from config.encryption_key (fails hard in production)."""
        logger.info("初始化加密服务")
        encryption_key = config.encryption_key
        if not encryption_key:
            if config.environment == "production":
                raise ValueError(
                    "ENCRYPTION_KEY must be set in production! "
                    "Use 'python generate_keys.py' to generate a secure key."
                )
            # Development: fall back to a fixed development-only key.
            logger.warning("[DEV] 未设置 ENCRYPTION_KEY使用开发环境默认密钥。")
            encryption_key = "dev-encryption-key-do-not-use-in-production"
            self._key_source = "development_default"
        else:
            self._key_source = "environment_variable"
        # Derive the Fernet key from whatever material we were given.
        key = self._derive_fernet_key(encryption_key)
        self._cipher = Fernet(key)
        logger.info(f"加密服务初始化成功 (key_source={self._key_source})")

    def _derive_fernet_key(self, encryption_key: str) -> bytes:
        """
        Derive a Fernet-compatible key from a password/key string.

        Args:
            encryption_key: Raw key material.

        Returns:
            A base64-encoded Fernet-compatible key.
        """
        # First, try using the value directly as a Fernet key.
        try:
            key_bytes = (
                encryption_key.encode() if isinstance(encryption_key, str) else encryption_key
            )
            # Validates it is a proper Fernet key (32 bytes, base64-encoded).
            Fernet(key_bytes)
            return key_bytes
        except Exception:
            pass
        # Not a valid Fernet key — derive one with PBKDF2.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=self.APP_SALT,
            iterations=100000,
        )
        derived_key = kdf.derive(encryption_key.encode())
        return base64.urlsafe_b64encode(derived_key)

    def encrypt(self, plaintext: str) -> str:
        """
        Encrypt a string.

        Args:
            plaintext: Plaintext input; empty/falsy values pass through unchanged.

        Returns:
            The encrypted string (base64-encoded).

        Raises:
            ValueError: When encryption fails.
        """
        if not plaintext:
            return plaintext
        try:
            encrypted = self._cipher.encrypt(plaintext.encode())
            return base64.urlsafe_b64encode(encrypted).decode()
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            raise ValueError("Failed to encrypt data")

    def decrypt(self, ciphertext: str, silent: bool = False) -> str:
        """
        Decrypt a string.

        Args:
            ciphertext: Encrypted string (base64-encoded); empty values pass through.
            silent: When True, suppress the error log on failure.

        Returns:
            The decrypted plaintext string.

        Raises:
            DecryptionException: When decryption fails.
        """
        if not ciphertext:
            return ciphertext
        try:
            encrypted = base64.urlsafe_b64decode(ciphertext.encode())
            decrypted = self._cipher.decrypt(encrypted)
            return decrypted.decode()
        except Exception as e:
            if not silent:
                logger.error(f"Decryption failed: {e}")
            # Raise a typed exception so callers can decide whether to log a
            # stack trace based on the exception type.
            raise DecryptionException(
                message=f"解密失败: {str(e)}。可能原因: ENCRYPTION_KEY 已改变或数据已损坏。解决方案: 请在管理面板重新设置 Provider API Key。",
                details={"original_error": str(e), "key_source": self._key_source},
            )

    def hash_api_key(self, api_key: str) -> str:
        """
        Hash an API key for lookup purposes.

        Args:
            api_key: API key plaintext.

        Returns:
            The SHA-256 hex digest.
        """
        return hashlib.sha256(api_key.encode()).hexdigest()
# Global encryption-service instance (the class itself is a singleton).
crypto_service = CryptoService()

32
src/core/enums.py Normal file
View File

@@ -0,0 +1,32 @@
"""
统一的枚举定义
避免重复定义造成的不一致
"""
from enum import Enum
class APIFormat(Enum):
    """API format enum — determines how requests/responses are processed."""

    CLAUDE = "CLAUDE"  # Claude API format
    OPENAI = "OPENAI"  # OpenAI API format
    CLAUDE_CLI = "CLAUDE_CLI"  # Claude CLI API format (uses authorization: Bearer)
    OPENAI_CLI = "OPENAI_CLI"  # OpenAI CLI/Responses API format (for clients such as Claude Code)
    GEMINI = "GEMINI"  # Google Gemini API format
    GEMINI_CLI = "GEMINI_CLI"  # Gemini CLI API format
class UserRole(Enum):
    """User role enum."""

    ADMIN = "admin"
    USER = "user"
class ProviderBillingType(Enum):
    """Provider billing type."""

    MONTHLY_QUOTA = "monthly_quota"  # monthly quota plan
    PAY_AS_YOU_GO = "pay_as_you_go"  # pay-as-you-go
    FREE_TIER = "free_tier"  # free tier

675
src/core/exceptions.py Normal file
View File

@@ -0,0 +1,675 @@
"""
统一的异常处理和错误响应定义
安全说明:
- 生产环境不返回详细错误信息,避免信息泄露
- 使用错误 ID 关联日志,便于排查问题
- 开发环境可返回详细信息用于调试
"""
import asyncio
import re
import traceback
import uuid
from typing import Any, Dict, List, Optional
import httpx
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from ..config import config
from src.core.logger import logger
# Pydantic 错误消息中英文翻译映射
# Regex map translating Pydantic English error messages to Chinese.
# Order matters: the first matching pattern wins, so the generic
# "Input should be (.+)" catch-all must stay below the specific numeric rules.
PYDANTIC_ERROR_TRANSLATIONS = {
    # String validation
    r"String should have at least (\d+) characters?": r"字符串长度至少需要 \1 个字符",
    r"String should have at most (\d+) characters?": r"字符串长度最多 \1 个字符",
    r"string_too_short": "字符串长度不足",
    r"string_too_long": "字符串长度超出限制",
    # Required fields
    r"Field required": "此字段为必填项",
    r"field required": "此字段为必填项",
    r"missing": "缺少必填字段",
    # Type errors
    r"Input should be a valid string": "输入应为有效的字符串",
    r"Input should be a valid integer": "输入应为有效的整数",
    r"Input should be a valid number": "输入应为有效的数字",
    r"Input should be a valid boolean": "输入应为布尔值",
    r"Input should be a valid email address": "输入应为有效的邮箱地址",
    r"Input should be a valid list": "输入应为有效的列表",
    r"Input should be a valid dictionary": "输入应为有效的字典",
    # Numeric validation
    r"Input should be greater than (\d+)": r"数值应大于 \1",
    r"Input should be greater than or equal to (\d+)": r"数值应大于或等于 \1",
    r"Input should be less than (\d+)": r"数值应小于 \1",
    r"Input should be less than or equal to (\d+)": r"数值应小于或等于 \1",
    # Enum validation
    r"Input should be (.+)": r"输入应为 \1",
    # Miscellaneous
    r"value is not a valid email address": "邮箱地址格式无效",
    r"invalid.*email": "邮箱地址格式无效",
    r"Extra inputs are not permitted": "不允许额外的字段",
    r"Value error, (.+)": r"\1",  # custom-validator messages are used verbatim
}
# Field-name translation map (English schema field -> Chinese display name).
FIELD_NAME_TRANSLATIONS = {
    "password": "密码",
    "username": "用户名",
    "email": "邮箱",
    "role": "角色",
    "quota_usd": "配额",
    "name": "名称",
    "title": "标题",
    "content": "内容",
    "ip_address": "IP地址",
    "reason": "原因",
    "ttl": "过期时间",
    "enabled": "启用状态",
    "fixed_limit": "固定限制",
    "old_password": "旧密码",
    "new_password": "新密码",
    "allowed_providers": "允许的提供商",
    "allowed_models": "允许的模型",
    "rate_limit": "速率限制",
    "expire_days": "过期天数",
    "priority": "优先级",
    "type": "类型",
    "is_active": "激活状态",
    "is_pinned": "置顶状态",
    "start_time": "开始时间",
    "end_time": "结束时间",
}
def translate_pydantic_error(error: Dict[str, Any]) -> str:
    """
    Translate a single Pydantic validation error into Chinese.

    Args:
        error: Pydantic error dict containing loc, msg, type, etc.

    Returns:
        The translated Chinese error message, prefixed with the field name
        when one is present.
    """
    # Field name: first element of the error location, translated when known.
    loc = error.get("loc", [])
    field = str(loc[0]) if loc else ""
    field_zh = FIELD_NAME_TRANSLATIONS.get(field, field)

    # Message: apply the first matching translation pattern, if any.
    msg = error.get("msg", "验证失败")
    match = next(
        (
            (pattern, replacement)
            for pattern, replacement in PYDANTIC_ERROR_TRANSLATIONS.items()
            if re.search(pattern, msg, re.IGNORECASE)
        ),
        None,
    )
    translated_msg = (
        re.sub(match[0], match[1], msg, flags=re.IGNORECASE) if match else msg
    )

    return f"{field_zh}: {translated_msg}" if field_zh else translated_msg
def translate_pydantic_errors(errors: List[Dict[str, Any]]) -> str:
    """
    Translate several Pydantic validation errors.

    Args:
        errors: List of Pydantic error dicts.

    Returns:
        The translated messages joined with "; ", or a generic message when
        the list is empty.
    """
    if not errors:
        return "请求数据验证失败"
    return "; ".join(translate_pydantic_error(err) for err in errors)
# Lazy import of the resilience manager to avoid a circular import.
def get_resilience_manager():
    """Return the resilience manager, or None when the module is unavailable."""
    try:
        from ..core.resilience import resilience_manager

        return resilience_manager
    except ImportError:
        return None
class ProxyException(HTTPException):
    """Base exception for the proxy service.

    Carries a machine-readable error_type and an optional details dict on top
    of FastAPI's HTTPException (detail is set to the human-readable message).
    """

    def __init__(
        self,
        status_code: int,
        error_type: str,
        message: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        self.error_type = error_type
        self.message = message
        self.details = details or {}
        super().__init__(status_code=status_code, detail=message)
class ProviderException(ProxyException):
    """Provider-related error, surfaced as HTTP 503.

    Extra keyword arguments are merged into the details dict; the request
    metadata object (if any) is kept so retry/recording layers can reuse it.
    """

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
        **kwargs,
    ):
        self.request_metadata = request_metadata  # preserved for downstream layers
        details: Dict[str, Any] = {}
        if provider_name:
            details["provider"] = provider_name
        details.update(kwargs)
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            error_type="provider_error",
            message=message,
            details=details,
        )
class ProviderNotAvailableException(ProviderException):
    """Provider unavailable (pure pass-through of the base constructor)."""

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        super().__init__(
            message=message,
            provider_name=provider_name,
            request_metadata=request_metadata,
        )
class ProviderTimeoutException(ProviderException):
    """Provider request timed out; the timeout (seconds) is added to details."""

    def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
        super().__init__(
            message=f"提供商 '{provider_name}' 请求超时({timeout}秒)",
            provider_name=provider_name,
            request_metadata=request_metadata,
            timeout=timeout,
        )
class ProviderAuthException(ProviderException):
    """Provider authentication failed (bad/expired API key upstream)."""

    def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
        super().__init__(
            message=f"提供商 '{provider_name}' 认证失败请检查API密钥",
            provider_name=provider_name,
            request_metadata=request_metadata,
        )
class ProviderRateLimitException(ProviderException):
    """Provider rate limited; keeps the upstream response headers and
    retry-after hint so the retry layer can honor them."""

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
        response_headers: Optional[Dict[str, str]] = None,  # upstream response headers
        retry_after: Optional[int] = None,  # suggested retry delay (seconds)
    ):
        self.response_headers = response_headers or {}  # kept for retry handling
        self.retry_after = retry_after  # kept for retry handling
        super().__init__(
            message=message, provider_name=provider_name, request_metadata=request_metadata
        )
class QuotaExceededException(ProxyException):
    """Quota exhausted, surfaced as HTTP 429."""

    def __init__(self, quota_type: str = "tokens", remaining: Optional[float] = None):
        # Append the remaining amount only when it is known.
        suffix = "" if remaining is None else f"(剩余: {remaining})"
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="quota_exceeded",
            message=f"{quota_type}配额已用尽{suffix}",
            details={"quota_type": quota_type, "remaining": remaining},
        )
class RateLimitException(ProxyException):
    """Local rate limit hit, surfaced as HTTP 429."""

    def __init__(self, limit: int, window: str = "minute"):
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="rate_limit",
            message=f"请求过于频繁,限制为每{window} {limit}",
            details={"limit": limit, "window": window},
        )
class ConcurrencyLimitError(ProxyException):
    """Concurrency limit exceeded, surfaced as HTTP 429."""

    def __init__(
        self, message: str, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
    ):
        # Only include identifiers that were actually provided.
        details = {
            label: value
            for label, value in (("endpoint_id", endpoint_id), ("key_id", key_id))
            if value
        }
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="concurrency_limit",
            message=message,
            details=details,
        )
class ModelNotSupportedException(ProxyException):
    """Requested model is not supported, surfaced as HTTP 400."""

    def __init__(self, model: str, provider_name: Optional[str] = None):
        # Mention the provider in the message only when one is known.
        message = (
            f"提供商 '{provider_name}' 不支持模型 '{model}'"
            if provider_name
            else f"模型 '{model}' 不受支持"
        )
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="model_not_supported",
            message=message,
            details={"model": model, "provider": provider_name},
        )
class StreamingNotSupportedException(ProxyException):
    """Streaming request not supported, surfaced as HTTP 400."""

    def __init__(self, model: str, provider_name: Optional[str] = None):
        # Mention the provider in the message only when one is known.
        message = (
            f"模型 '{model}' 在提供商 '{provider_name}' 上不支持流式请求"
            if provider_name
            else f"模型 '{model}' 不支持流式请求"
        )
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="streaming_not_supported",
            message=message,
            details={"model": model, "provider": provider_name},
        )
class InvalidRequestException(ProxyException):
    """Invalid request, surfaced as HTTP 400."""

    def __init__(self, message: str, field: Optional[str] = None):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="invalid_request",
            message=message,
            details={"field": field} if field else {},
        )
class NotFoundException(ProxyException):
    """Resource not found, surfaced as HTTP 404."""

    def __init__(self, message: str, resource_type: Optional[str] = None):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            error_type="not_found",
            message=message,
            details={"resource_type": resource_type} if resource_type else {},
        )
class ForbiddenException(ProxyException):
    """Insufficient permissions, surfaced as HTTP 403."""

    def __init__(self, message: str, required_role: Optional[str] = None):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            error_type="forbidden",
            message=message,
            details={"required_role": required_role} if required_role else {},
        )
class DecryptionException(ProxyException):
    """Decryption failed — a known configuration issue; no stack trace needed."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            error_type="decryption_error",
            message=message,
            details=details or {},
        )
class JSONParseException(ProviderException):
    """Upstream returned a response body that could not be parsed as JSON."""

    def __init__(
        self,
        provider_name: str,
        original_error: str,
        response_content: Optional[str] = None,
        content_type: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        extra: Dict[str, Any] = {
            "original_error": original_error,
            "content_type": content_type,
        }
        if response_content:
            if len(response_content) > 500:
                # Truncate long bodies but keep the head and tail for context.
                extra["response_preview"] = (
                    f"{response_content[:200]}...{response_content[-200:]}"
                )
            else:
                extra["response_content"] = response_content
        super().__init__(
            message=f"提供商 '{provider_name}' 返回了无效的JSON响应",
            provider_name=provider_name,
            request_metadata=request_metadata,
            **extra,
        )
class EmptyStreamException(ProviderException):
    """Upstream answered HTTP 200 for a stream but sent no data at all."""

    def __init__(
        self,
        provider_name: str,
        chunk_count: int = 0,
        request_metadata: Optional[Any] = None,
    ):
        empty_stream_message = (
            f"提供商 '{provider_name}' 返回了空的流式响应status=200 但无数据)"
        )
        super().__init__(
            message=empty_stream_message,
            provider_name=provider_name,
            request_metadata=request_metadata,
            chunk_count=chunk_count,
        )
class EmbeddedErrorException(ProviderException):
    """Error embedded in a 2xx response body.

    Some providers (e.g. Gemini) answer HTTP 200 while the body carries an
    error payload; this exception surfaces it so retry logic can kick in.
    """

    def __init__(
        self,
        provider_name: str,
        error_code: Optional[int] = None,
        error_message: Optional[str] = None,
        error_status: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        fragments = [f"提供商 '{provider_name}' 返回了嵌套错误"]
        if error_code:
            fragments.append(f" (code={error_code})")
        if error_message:
            fragments.append(f": {error_message}")
        super().__init__(
            message="".join(fragments),
            provider_name=provider_name,
            request_metadata=request_metadata,
            error_code=error_code,
            error_status=error_status,
        )
        # Keep the raw fields available to callers inspecting the error.
        self.error_code = error_code
        self.error_message = error_message
        self.error_status = error_status
class UpstreamClientException(ProxyException):
    """Client-side (HTTP 4xx) error reported by the upstream provider; never retried.

    These errors stem from the user's own request (oversized images, invalid
    parameters, content-policy violations), so switching providers cannot help.
    """

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        status_code: int = 400,
        error_type: Optional[str] = None,
        upstream_error: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        self.upstream_error = upstream_error
        self.request_metadata = request_metadata
        extra = {}
        for key, value in (
            ("provider", provider_name),
            ("upstream_error_type", error_type),
            ("upstream_error", upstream_error),
        ):
            if value:
                extra[key] = value
        super().__init__(
            status_code=status_code,
            error_type="upstream_client_error",
            message=message,
            details=extra,
        )
class ErrorResponse:
    """Unified error-response formatter.

    Produces a consistent JSON body: {"error": {"type", "message", "details"}}.
    """

    @staticmethod
    def create(
        error_type: str,
        message: str,
        status_code: int = 500,
        details: Optional[Dict[str, Any]] = None,
    ) -> JSONResponse:
        """Create a standard error response and log it."""
        error_body = {"error": {"type": error_type, "message": message}}
        if details:
            error_body["error"]["details"] = details
        # Every error response is logged for observability.
        logger.error(f"Error response: {error_type} - {message}")
        return JSONResponse(status_code=status_code, content=error_body)

    @staticmethod
    def from_exception(e: Exception) -> JSONResponse:
        """Build an error response from an exception.

        Security notes:
        - Production returns only an error ID, never internal details.
        - Development returns the full error (with traceback) for debugging.
        - All errors are logged; the error ID correlates log and response.
        """
        if isinstance(e, ProxyException):
            return ErrorResponse.create(
                error_type=e.error_type,
                message=e.message,
                status_code=e.status_code,
                details=e.details,
            )
        elif isinstance(e, HTTPException):
            return ErrorResponse.create(
                error_type="http_error", message=str(e.detail), status_code=e.status_code
            )
        else:
            # Unknown exception: short ID so users can report it.
            error_id = str(uuid.uuid4())[:8]
            error_type_name = type(e).__name__
            error_message = str(e)
            # Always log the full error.
            logger.error(f"[{error_id}] Unexpected error: {error_type_name}: {error_message}")
            # Detail level depends on the environment.
            is_development = config.environment in ("development", "test", "testing")
            if is_development:
                return ErrorResponse.create(
                    error_type="internal_error",
                    message=f"内部服务器错误: {error_type_name}: {error_message}",
                    status_code=500,
                    details={
                        "error_id": error_id,
                        "error_type": error_type_name,
                        "error": error_message,
                        "traceback": traceback.format_exc().split("\n"),
                    },
                )
            else:
                return ErrorResponse.create(
                    error_type="internal_error",
                    message="内部服务器错误",
                    status_code=500,
                    details={
                        "error_id": error_id,
                        "support_info": "请联系管理员并提供此错误 ID",
                    },
                )

    @staticmethod
    def provider_error(provider_name: str, error: Exception) -> JSONResponse:
        """Provider error response, dispatched primarily on exception type.

        BUG FIX: the original returned None for httpx.HTTPStatusError with a
        status other than 401/429 (the branch fell off the end); every path
        now returns a JSONResponse.
        """

        def _generic() -> JSONResponse:
            # Shared fallback for errors we cannot classify more precisely.
            return ErrorResponse.create(
                error_type="provider_error",
                message=f"提供商请求失败: {str(error)}",
                status_code=503,
                details={
                    "provider": provider_name,
                    "error": str(error),
                    "error_type": type(error).__name__,
                },
            )

        if isinstance(error, (asyncio.TimeoutError, httpx.TimeoutException)):
            return ErrorResponse.from_exception(ProviderTimeoutException(provider_name, 60))
        elif isinstance(error, httpx.HTTPStatusError):
            if error.response.status_code == 401:
                return ErrorResponse.from_exception(ProviderAuthException(provider_name))
            elif error.response.status_code == 429:
                return ErrorResponse.from_exception(
                    ProviderRateLimitException(
                        message=f"提供商 '{provider_name}' 速率限制",
                        provider_name=provider_name,
                    )
                )
            # Other upstream HTTP errors fall through to the generic response
            # (previously this branch implicitly returned None).
            return _generic()
        elif isinstance(error, (httpx.ConnectError, httpx.NetworkError)):
            return ErrorResponse.create(
                error_type="provider_connection_error",
                message=f"无法连接到提供商 {provider_name}",
                status_code=503,
                details={"provider": provider_name, "error": "Connection failed"},
            )
        # Last resort: string matching when the type is not conclusive.
        elif "auth" in str(error).lower() or "401" in str(error):
            return ErrorResponse.from_exception(ProviderAuthException(provider_name))
        elif "rate limit" in str(error).lower() or "429" in str(error):
            return ErrorResponse.from_exception(
                ProviderRateLimitException(
                    message=f"提供商 '{provider_name}' 速率限制",
                    provider_name=provider_name,
                )
            )
        else:
            return _generic()
class ExceptionHandlers:
    """FastAPI exception handlers, registered at app startup."""
    @staticmethod
    async def handle_proxy_exception(request, exc: ProxyException):
        """Handle ProxyException subclasses raised by application code."""
        return ErrorResponse.from_exception(exc)
    @staticmethod
    async def handle_http_exception(request, exc: HTTPException):
        """Handle FastAPI/Starlette HTTPException."""
        return ErrorResponse.from_exception(exc)
    @staticmethod
    async def handle_generic_exception(request, exc: Exception):
        """Handle any uncaught exception, integrating the resilience manager."""
        # HTTPException is delegated to the dedicated HTTP handler first.
        if isinstance(exc, HTTPException):
            return await ExceptionHandlers.handle_http_exception(request, exc)
        # Collect request information as error context.
        request_info = {
            "method": request.method,
            "url": str(request.url),
            "client_ip": (
                getattr(request.client, "host", "unknown")
                if hasattr(request, "client")
                else "unknown"
            ),
            "user_agent": request.headers.get("user-agent", "unknown"),
        }
        # Let the resilience manager classify and record the error.
        rm = get_resilience_manager()
        if rm:
            try:
                error_result = rm.handle_error(
                    error=exc,
                    context=request_info,
                    operation=f"{request.method} {request.url.path}",
                )
                # Map error severity onto an HTTP status code.
                if error_result.get("severity") and error_result["severity"].value == "critical":
                    status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif error_result.get("severity") and error_result["severity"].value == "high":
                    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                else:
                    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                return ErrorResponse.create(
                    status_code=status_code,
                    error_type="system_error",
                    message=error_result.get("user_message", "系统遇到未知错误"),
                    details={
                        "error_id": error_result.get("error_id"),
                        "recovery_info": "请稍后重试,如问题持续请联系管理员",
                    },
                )
            except Exception as resilience_error:
                # If the resilience manager itself fails, degrade to basic handling.
                logger.exception("韧性管理器处理异常时出错")
        # Fallback path: basic exception-to-response conversion.
        return ErrorResponse.from_exception(exc)

View File

@@ -0,0 +1,247 @@
"""
Key 能力系统
能力类型:
1. 互斥能力 (EXCLUSIVE): 需要时选有的,不需要时选没有的(如 cache_1h
2. 兼容能力 (COMPATIBLE): 需要时选有的,不需要时都可选(如 context_1m
配置模式:
1. user_configurable: 用户可配置(模型级 + Key级强制
2. auto_detect: 自动检测(请求失败后升级)
3. request_param: 从请求参数检测
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple
class CapabilityMatchMode(Enum):
    """How a capability participates in key selection."""
    EXCLUSIVE = "exclusive"  # exclusive: pick keys WITH it when required, keys WITHOUT it otherwise
    COMPATIBLE = "compatible"  # compatible: pick keys WITH it when required; otherwise any key works
class CapabilityConfigMode(Enum):
    """How a capability gets configured or detected."""
    USER_CONFIGURABLE = "user_configurable"  # user-configurable (model level + forced at key level)
    AUTO_DETECT = "auto_detect"  # auto-detected (upgraded after a failed request)
    REQUEST_PARAM = "request_param"  # detected from request parameters
@dataclass
class CapabilityDefinition:
    """Metadata describing one key capability."""
    name: str  # canonical capability identifier, e.g. "cache_1h"
    display_name: str  # human-readable name for UIs
    description: str  # longer explanation shown to users
    match_mode: CapabilityMatchMode  # how the capability affects key selection
    config_mode: CapabilityConfigMode  # how the capability is configured/detected
    short_name: str = ""  # compact display name (lists and other tight layouts)
    error_patterns: List[str] = field(default_factory=list)  # keyword group identifying related upstream errors
# ============ Capability registry ============
_capabilities: Dict[str, CapabilityDefinition] = {}
def register_capability(
    name: str,
    display_name: str,
    description: str,
    match_mode: CapabilityMatchMode,
    config_mode: CapabilityConfigMode,
    short_name: str = "",
    error_patterns: Optional[List[str]] = None,
) -> CapabilityDefinition:
    """Create a CapabilityDefinition, store it in the registry, and return it."""
    definition = CapabilityDefinition(
        name=name,
        display_name=display_name,
        description=description,
        match_mode=match_mode,
        config_mode=config_mode,
        # Fall back to the display name when no short name is given.
        short_name=short_name if short_name else display_name,
        error_patterns=error_patterns or [],
    )
    _capabilities[name] = definition
    return definition
def get_capability(name: str) -> Optional[CapabilityDefinition]:
    """Return the capability definition registered under *name*, or None."""
    return _capabilities.get(name)
def get_all_capabilities() -> List[CapabilityDefinition]:
    """Return every registered capability definition."""
    return list(_capabilities.values())
def get_user_configurable_capabilities() -> List[CapabilityDefinition]:
    """Return only the capabilities users may configure themselves."""
    return [c for c in _capabilities.values() if c.config_mode == CapabilityConfigMode.USER_CONFIGURABLE]
# ============ Capability match checking ============
def check_capability_match(
    key_capabilities: Optional[Dict[str, bool]],
    requirements: Optional[Dict[str, bool]],
) -> Tuple[bool, Optional[str]]:
    """Decide whether a key's capabilities satisfy a request's requirements.

    Matching rules:
    - EXCLUSIVE capabilities must agree exactly: a key that has one is only
      eligible when the request explicitly requires it (an undeclared
      requirement counts as "not needed" — that avoids wasting premium
      resources), and a key lacking a required one is rejected.
    - COMPATIBLE capabilities only matter when required: a key lacking a
      required one is rejected, but an unneeded one is harmless.

    Args:
        key_capabilities: capabilities the key owns, e.g. {"cache_1h": True}
        requirements: capabilities the request declared, e.g. {"context_1m": False}

    Returns:
        (is_match, skip_reason) — skip_reason is None when the key matches.
    """
    owned = key_capabilities or {}
    declared = requirements or {}

    # Pass 1: every declared requirement must be satisfiable by this key.
    for cap_name, needed in declared.items():
        definition = _capabilities.get(cap_name)
        if definition is None:
            continue
        key_has = owned.get(cap_name, False)
        if needed and not key_has:
            return False, f"需要{definition.display_name}但 Key 不支持"
        if (
            not needed
            and key_has
            and definition.match_mode == CapabilityMatchMode.EXCLUSIVE
        ):
            return False, f"不需要{definition.display_name}(避免浪费高价资源)"

    # Pass 2: an EXCLUSIVE capability the key owns but the request never
    # declared counts as "not needed" and disqualifies the key.
    for cap_name, key_has in owned.items():
        if not key_has or cap_name in declared:
            continue
        definition = _capabilities.get(cap_name)
        if definition is not None and definition.match_mode == CapabilityMatchMode.EXCLUSIVE:
            return False, f"不需要{definition.display_name}(避免浪费高价资源)"

    return True, None
def _match_error_patterns(error_msg: str, patterns: List[str]) -> bool:
"""检查错误信息是否匹配模式(所有关键词都要出现)"""
if not patterns:
return False
msg_lower = error_msg.lower()
return all(p.lower() in msg_lower for p in patterns)
def detect_capability_upgrade_from_error(
    error_msg: str,
    current_requirements: Optional[Dict[str, bool]] = None,
) -> Optional[str]:
    """Inspect an upstream error message for a capability worth upgrading to.

    Args:
        error_msg: error message returned by the provider
        current_requirements: capability requirements already in effect

    Returns:
        Name of the capability to upgrade, or None when no upgrade applies.
    """
    existing = current_requirements or {}
    for definition in _capabilities.values():
        # Skip capabilities already required and those without error signatures.
        if existing.get(definition.name) or not definition.error_patterns:
            continue
        if _match_error_patterns(error_msg, definition.error_patterns):
            return definition.name
    return None
# ============ Backwards-compatibility aliases ============
# Keep the old public API importable.
get_capability_definition = get_capability
class _CapabilityDefinitionsProxy:
    """Dict-like facade over the capability registry (legacy CAPABILITY_DEFINITIONS access)."""

    def get(self, name: str) -> Optional[CapabilityDefinition]:
        return _capabilities.get(name)

    def __getitem__(self, name: str) -> CapabilityDefinition:
        entry = _capabilities.get(name)
        if entry is None:
            raise KeyError(name)
        return entry

    def __contains__(self, name: str) -> bool:
        return name in _capabilities

    def values(self) -> List[CapabilityDefinition]:
        return list(_capabilities.values())

    def items(self) -> List[Tuple[str, CapabilityDefinition]]:
        return list(_capabilities.items())
CAPABILITY_DEFINITIONS = _CapabilityDefinitionsProxy()
# ============ Legacy plugin base class (being phased out) ============
CapabilityPlugin = CapabilityDefinition  # type alias kept for old code
# ============ Register built-in capabilities ============
register_capability(
    name="cache_1h",
    display_name="1 小时缓存",
    description="使用 1 小时缓存 TTL价格更高适合长对话",
    match_mode=CapabilityMatchMode.EXCLUSIVE,
    config_mode=CapabilityConfigMode.USER_CONFIGURABLE,
    short_name="1h缓存",
)
register_capability(
    name="context_1m",
    display_name="CLI 1M 上下文",
    description="支持 1M tokens 上下文窗口",
    match_mode=CapabilityMatchMode.COMPATIBLE,
    config_mode=CapabilityConfigMode.REQUEST_PARAM,
    short_name="CLI 1M",
    error_patterns=["context", "token", "length", "exceed"],  # keywords of context-overflow errors
)

135
src/core/logger.py Normal file
View File

@@ -0,0 +1,135 @@
"""
统一日志系统 - 基于 loguru
日志级别策略:
- DEBUG: 开发调试,详细执行流程、变量值、缓存操作
- INFO: 生产环境,关键业务操作、状态变更、请求处理
- WARNING: 潜在问题、降级处理、资源警告
- ERROR: 异常错误、需要关注的故障
输出策略:
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
- 文件: 始终保存 DEBUG 级别保留30天每日轮转
使用方式:
from src.core.logger import logger
logger.info("消息")
logger.debug("调试信息")
logger.warning("警告")
logger.error("错误")
logger.exception("异常,带堆栈")
"""
import logging
import os
import sys
from pathlib import Path
from loguru import logger
# ============================================================================
# 环境检测
# ============================================================================
IS_DOCKER = (
os.path.exists("/.dockerenv")
or os.environ.get("DOCKER_CONTAINER", "false").lower() == "true"
)
# 日志级别: 默认开发环境 DEBUG, 生产环境 INFO
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG" if not IS_DOCKER else "INFO").upper()
# 是否禁用文件日志 (用于测试或特殊场景)
DISABLE_FILE_LOG = os.getenv("LOG_DISABLE_FILE", "false").lower() == "true"
# 项目根目录
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# ============================================================================
# 日志格式定义
# ============================================================================
CONSOLE_FORMAT_DEV = (
"<green>{time:HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{message}</cyan>"
)
CONSOLE_FORMAT_PROD = "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}"
FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}"
# ============================================================================
# 日志配置
# ============================================================================
logger.remove()
def _log_filter(record):
return "watchfiles" not in record["name"]
if IS_DOCKER:
logger.add(
sys.stdout,
format=CONSOLE_FORMAT_PROD,
level=LOG_LEVEL,
filter=_log_filter,
colorize=False,
)
else:
logger.add(
sys.stdout,
format=CONSOLE_FORMAT_DEV,
level=LOG_LEVEL,
filter=_log_filter,
colorize=True,
)
if not DISABLE_FILE_LOG:
log_dir = PROJECT_ROOT / "logs"
log_dir.mkdir(exist_ok=True)
# 主日志文件 - 所有级别
logger.add(
log_dir / "app.log",
format=FILE_FORMAT,
level="DEBUG",
filter=_log_filter,
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
)
# 错误日志文件 - 仅 ERROR 及以上
logger.add(
log_dir / "error.log",
format=FILE_FORMAT,
level="ERROR",
filter=_log_filter,
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
)
# ============================================================================
# 禁用第三方库噪音日志
# ============================================================================
logging.getLogger("watchfiles").setLevel(logging.ERROR)
logging.getLogger("watchfiles.main").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# ============================================================================
# 导出
# ============================================================================
__all__ = ["logger"]

46
src/core/metrics.py Normal file
View File

@@ -0,0 +1,46 @@
"""
Prometheus metrics for monitoring
"""
from prometheus_client import Counter, Gauge, Histogram
# 并发槽位占用时长分布
concurrency_slot_duration_seconds = Histogram(
"concurrency_slot_duration_seconds",
"Duration of concurrency slot occupation in seconds",
["key_id", "exception"],
buckets=[0.1, 0.5, 1, 5, 10, 30, 60, 120, 300, 600], # 0.1s 到 10 分钟
)
# 并发槽位释放计数
concurrency_slot_release_total = Counter(
"concurrency_slot_release_total",
"Total number of concurrency slot releases",
["key_id", "exception"],
)
# 当前并发槽位使用数
concurrency_slots_in_use = Gauge(
"concurrency_slots_in_use", "Current number of concurrency slots in use", ["key_id"]
)
# 流式请求时长分布
streaming_request_duration_seconds = Histogram(
"streaming_request_duration_seconds",
"Duration of streaming requests in seconds",
["key_id", "status"],
buckets=[1, 5, 10, 30, 60, 120, 300, 600, 1800], # 1s 到 30 分钟
)
# 请求总数(按类型)
request_total = Counter(
"request_total",
"Total number of requests",
["type", "status"], # type values: streaming/non-streaming, status: success/error
)
# 健康监控相关
health_open_circuits = Gauge(
"health_open_circuits",
"Number of provider keys currently in circuit breaker open state",
)

View File

@@ -0,0 +1,106 @@
"""
优化工具类 - 包含Token计数和响应头管理
"""
from typing import Any, Dict, Optional
import tiktoken
class TokenCounter:
    """Improved token counter with per-model encoder caching.

    Falls back to a rough chars/4 estimate whenever tiktoken is unavailable
    or encoding fails.
    """

    # Mapping of model-name prefixes to tiktoken encoding names.
    MODEL_TO_ENCODING = {
        "gpt-4": "cl100k_base",
        "gpt-3.5-turbo": "cl100k_base",
        "claude-3": "cl100k_base",  # Claude uses a similar tokenizer
        "claude-2": "cl100k_base",
    }

    def __init__(self):
        self._encodings = {}  # cache: normalized model name -> encoder
        self._default_encoding = None  # shared fallback encoder, created lazily

    def _get_encoding(self, model: str):
        """Return (and cache) the tiktoken encoder for *model*.

        BUG FIX: the original keyed the lookup on model.split("-")[0]
        ("gpt", "claude"), which never matched any MODEL_TO_ENCODING key
        ("gpt-4", "claude-3"), so the mapping table was dead code. The
        lookup now matches by model-name prefix.
        """
        key = model.lower()
        if key not in self._encodings:
            encoding_name = "cl100k_base"  # default encoder
            for prefix, name in self.MODEL_TO_ENCODING.items():
                if key.startswith(prefix):
                    encoding_name = name
                    break
            try:
                self._encodings[key] = tiktoken.get_encoding(encoding_name)
            except Exception:
                # On failure, fall back to the shared default encoder.
                if not self._default_encoding:
                    self._default_encoding = tiktoken.get_encoding("cl100k_base")
                self._encodings[key] = self._default_encoding
        return self._encodings[key]

    def count_tokens(self, text: str, model: str = "claude-3") -> int:
        """Count tokens in *text* precisely; degrades to len(text)//4 on any error."""
        if not text:
            return 0
        try:
            encoding = self._get_encoding(model)
            return len(encoding.encode(text))
        except Exception:
            # Degrade to a simple estimate (roughly 4 chars per token).
            return len(text) // 4

    def count_messages_tokens(self, messages: list, model: str = "claude-3") -> int:
        """Count tokens across a list of chat messages (role overhead included)."""
        total = 0
        for message in messages:
            if isinstance(message, dict):
                # Fixed overhead for the role marker and separators.
                total += 4
                content = message.get("content", "")
                if isinstance(content, str):
                    total += self.count_tokens(content, model)
                elif isinstance(content, list):
                    # Multimodal content: only the text parts are counted.
                    for item in content:
                        if isinstance(item, dict) and "text" in item:
                            total += self.count_tokens(item["text"], model)
        return total

    def estimate_response_tokens(self, response: Any, model: str = "claude-3") -> int:
        """Estimate the token count of a response dict (Claude or OpenAI shaped)."""
        if isinstance(response, dict):
            if "content" in response:
                # Claude-style response body.
                content = response["content"]
                if isinstance(content, list):
                    text = " ".join(
                        item.get("text", "") for item in content if isinstance(item, dict)
                    )
                else:
                    text = str(content)
                return self.count_tokens(text, model)
            elif "choices" in response:
                # OpenAI-style response body.
                total = 0
                for choice in response.get("choices", []):
                    message = choice.get("message", {})
                    content = message.get("content", "")
                    total += self.count_tokens(content, model)
                return total
        # Fallback: rough estimate over the stringified response.
        return len(str(response)) // 4

205
src/core/provider_health.py Normal file
View File

@@ -0,0 +1,205 @@
"""
提供商健康度管理
基于简单的失败计数和优先级调整
"""
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, Optional
class ProviderHealthTracker:
    """Tracks provider health via sliding-window success/failure counts.

    Failures inside the window demote a provider's effective priority;
    sustained success (or a long quiet period) restores it.
    """

    def __init__(
        self,
        failure_window: int = 300,  # 5-minute sliding window
        failure_threshold: int = 3,  # failures before demotion
        recovery_time: int = 600,  # quiet seconds before a full reset
    ):
        self.failure_window = failure_window
        self.failure_threshold = failure_threshold
        self.recovery_time = recovery_time
        # Per-provider timestamps of failed requests.
        self.failures: Dict[str, list] = defaultdict(list)
        # Per-provider timestamps of successful requests.
        self.successes: Dict[str, list] = defaultdict(list)
        # Per-provider priority delta (negative = demoted).
        self.priority_adjustments: Dict[str, int] = {}

    def record_success(self, provider_name: str):
        """Record a successful request; sustained success earns back priority."""
        now = time.time()
        self.successes[provider_name].append(now)
        self._cleanup_old_records(provider_name, now)
        # Five recent successes win back one step of any prior demotion.
        demoted = self.priority_adjustments.get(provider_name, 0) < 0
        if len(self.successes[provider_name]) >= 5 and demoted:
            self.priority_adjustments[provider_name] += 1

    def record_failure(self, provider_name: str):
        """Record a failed request; repeated failures demote the provider."""
        now = time.time()
        self.failures[provider_name].append(now)
        self._cleanup_old_records(provider_name, now)
        if len(self.failures[provider_name]) >= self.failure_threshold:
            self.priority_adjustments[provider_name] = (
                self.priority_adjustments.get(provider_name, 0) - 1
            )

    def get_priority_adjustment(self, provider_name: str) -> int:
        """Current priority delta: negative lowers, positive raises priority."""
        return self.priority_adjustments.get(provider_name, 0)

    def get_health_status(self, provider_name: str) -> Dict:
        """Snapshot of the provider's recent health within the window."""
        now = time.time()
        self._cleanup_old_records(provider_name, now)
        fail_count = len(self.failures[provider_name])
        success_count = len(self.successes[provider_name])
        total = fail_count + success_count
        rate = fail_count / total if total > 0 else 0
        return {
            "provider": provider_name,
            "recent_failures": fail_count,
            "recent_successes": success_count,
            "failure_rate": rate,
            "priority_adjustment": self.get_priority_adjustment(provider_name),
            "status": self._get_status_label(rate, fail_count),
        }

    def _cleanup_old_records(self, provider_name: str, current_time: float):
        """Drop records outside the window; reset demotion after a quiet recovery period."""
        window = self.failure_window
        self.failures[provider_name] = [
            ts for ts in self.failures[provider_name] if current_time - ts < window
        ]
        self.successes[provider_name] = [
            ts for ts in self.successes[provider_name] if current_time - ts < window
        ]
        # With no recent failures and a demoted provider, check full recovery.
        demoted = self.priority_adjustments.get(provider_name, 0) < 0
        if demoted and not self.failures[provider_name]:
            fully_quiet = all(
                current_time - ts > self.recovery_time
                for ts in self.successes[provider_name]
            )
            if fully_quiet:
                self.priority_adjustments[provider_name] = 0

    def _get_status_label(self, failure_rate: float, recent_failures: int) -> str:
        """Map failure statistics onto a coarse status label."""
        if recent_failures >= self.failure_threshold:
            return "degraded"
        if failure_rate > 0.5:
            return "unstable"
        if failure_rate > 0.1:
            return "warning"
        return "healthy"

    def should_use_provider(self, provider_name: str) -> bool:
        """Simple gate: providers demoted below -3 are temporarily skipped."""
        return self.get_priority_adjustment(provider_name) > -3

    def reset_provider_health(self, provider_name: str):
        """Reset all health state for a provider (manual admin action)."""
        self.failures[provider_name] = []
        self.successes[provider_name] = []
        self.priority_adjustments[provider_name] = 0
class SimpleProviderSelector:
    """Picks a provider using base priority plus live health adjustments."""

    def __init__(self, health_tracker: ProviderHealthTracker):
        self.health_tracker = health_tracker

    def select_provider(self, providers: list, specified_provider: Optional[str] = None):
        """Select a provider.

        Args:
            providers: candidate providers (already sorted by base priority)
            specified_provider: provider name explicitly requested by the user

        Returns:
            The chosen provider, or None when nothing matches.
        """
        # An explicitly requested provider is honored regardless of health.
        if specified_provider:
            return next((p for p in providers if p.name == specified_provider), None)

        def effective_key(provider):
            # Effective priority = base priority + health adjustment;
            # id breaks ties (negated so smaller ids win under reverse sort).
            adjusted = provider.priority + self.health_tracker.get_priority_adjustment(
                provider.name
            )
            return (adjusted, -provider.id)

        ranked = sorted(providers, key=effective_key, reverse=True)
        # Prefer the first provider the tracker still considers usable.
        for candidate in ranked:
            if self.health_tracker.should_use_provider(candidate.name):
                return candidate
        # Degrade gracefully: everyone unhealthy, take the top one anyway.
        return ranked[0] if ranked else None

    def get_provider_rankings(self, providers: list) -> list:
        """Return current provider rankings (for debugging and monitoring)."""
        rankings = []
        for provider in providers:
            snapshot = self.health_tracker.get_health_status(provider.name)
            effective = provider.priority + snapshot["priority_adjustment"]
            rankings.append(
                {
                    "name": provider.name,
                    "base_priority": provider.priority,
                    "adjustment": snapshot["priority_adjustment"],
                    "effective_priority": effective,
                    "status": snapshot["status"],
                    "failure_rate": snapshot["failure_rate"],
                }
            )
        rankings.sort(key=lambda entry: entry["effective_priority"], reverse=True)
        return rankings

428
src/core/resilience.py Normal file
View File

@@ -0,0 +1,428 @@
"""
系统韧性和风险管控模块
提供全局的错误处理、自动恢复、降级策略和用户友好的错误体验
"""
import asyncio
import functools
import threading
import time
import traceback
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, Union
from ..core.exceptions import ProxyException
from src.core.logger import logger
class ErrorSeverity(Enum):
    """Error severity levels."""
    LOW = "low"  # minor error, core functionality unaffected
    MEDIUM = "medium"  # affects some functionality
    HIGH = "high"  # affects major functionality
    CRITICAL = "critical"  # threatens overall system availability
class RecoveryStrategy(Enum):
    """Recovery strategies applied when an error pattern matches."""
    RETRY = "retry"  # retry the operation
    FALLBACK = "fallback"  # switch to a backup plan
    CIRCUIT_BREAKER = "circuit_breaker"  # trip a circuit breaker
    GRACEFUL_DEGRADE = "graceful_degrade"  # degrade gracefully
    USER_NOTIFY = "user_notify"  # surface the error to the user
class ErrorPattern:
    """Declarative description of how a class of errors should be handled.

    Attributes map one-to-one to the constructor parameters.
    """
    def __init__(
        self,
        error_types: List[Type[Exception]],
        severity: ErrorSeverity,
        recovery_strategy: RecoveryStrategy,
        user_message: str,
        auto_recover: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        circuit_threshold: int = 5,
    ):
        # Exception classes this pattern matches (isinstance check).
        self.error_types = error_types
        self.severity = severity
        self.recovery_strategy = recovery_strategy
        # Friendly message surfaced to the client when this pattern fires.
        self.user_message = user_message
        # Whether retries may be attempted automatically.
        self.auto_recover = auto_recover
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Failures before an associated circuit breaker opens.
        self.circuit_threshold = circuit_threshold
class CircuitBreaker:
    """Thread-safe circuit breaker (states: closed -> open -> half-open).

    FIX over the original: the lock was held for the entire duration of the
    wrapped call, serializing every call through the breaker. The state
    check now happens under the lock, the call runs outside it, and the
    success/failure bookkeeping reacquires the lock.
    """

    def __init__(self, failure_threshold: int = 5, timeout: int = 60):
        self.failure_threshold = failure_threshold  # failures before opening
        self.timeout = timeout  # seconds to stay open before probing
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "closed"  # closed, open, half-open
        self._lock = threading.Lock()

    def call(self, func: Callable, *args, **kwargs):
        """Invoke *func* through the breaker; raises when the circuit is open."""
        with self._lock:
            if self.state == "open":
                if self._should_attempt_reset():
                    # Allow one probe through after the open period.
                    self.state = "half-open"
                else:
                    raise Exception("服务暂时不可用,请稍后重试")
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        self._on_success()
        return result

    def _should_attempt_reset(self) -> bool:
        """True when the open period has elapsed (caller must hold the lock)."""
        if self.last_failure_time is None:
            return True
        return time.time() - self.last_failure_time >= self.timeout

    def _on_success(self):
        """Reset the breaker after a successful call."""
        with self._lock:
            self.failure_count = 0
            self.state = "closed"

    def _on_failure(self):
        """Count a failure and open the circuit at the threshold."""
        with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count >= self.failure_threshold:
                self.state = "open"
class ResilienceManager:
    """System resilience manager: error patterns, statistics and circuit breakers."""
    def __init__(self):
        self.error_patterns: List[ErrorPattern] = []
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}
        # "<ErrorType>:<operation>" -> occurrence count
        self.error_stats: Dict[str, int] = {}
        # Ring buffer of the most recent 100 errors.
        self.last_errors: List[Dict[str, Any]] = []
        self._setup_default_patterns()
    def _setup_default_patterns(self):
        """Register the default error-handling patterns."""
        # Database connection errors — only catch specific DB-related exceptions.
        try:
            from sqlalchemy.exc import (
                DatabaseError,
                DisconnectionError,
                OperationalError,
                StatementError,
            )
            from sqlalchemy.exc import TimeoutError as SQLTimeoutError
            db_exceptions = [
                OperationalError,
                DisconnectionError,
                SQLTimeoutError,
                StatementError,
                DatabaseError,
            ]
        except ImportError:
            # SQLAlchemy unavailable: fall back to generic exception types.
            db_exceptions = [ConnectionError, OSError]
        self.add_error_pattern(
            ErrorPattern(
                error_types=db_exceptions,
                severity=ErrorSeverity.HIGH,
                recovery_strategy=RecoveryStrategy.RETRY,
                user_message="数据库连接异常,正在重试...",
                max_retries=3,
                retry_delay=1.0,
            )
        )
        # Authentication errors — only catch specific auth exceptions.
        try:
            from ..core.exceptions import ForbiddenException, ProviderAuthException
            auth_exceptions = [ProviderAuthException, ForbiddenException]
        except ImportError:
            # If the specific exceptions cannot be imported, register nothing
            # (deliberately conservative: no catch-all fallback here).
            auth_exceptions = []
        if auth_exceptions:
            self.add_error_pattern(
                ErrorPattern(
                    error_types=auth_exceptions,
                    severity=ErrorSeverity.MEDIUM,
                    recovery_strategy=RecoveryStrategy.USER_NOTIFY,
                    user_message="认证失败请检查API密钥或重新登录",
                    auto_recover=False,
                )
            )
        # Network request errors.
        self.add_error_pattern(
            ErrorPattern(
                error_types=[ConnectionError, TimeoutError],
                severity=ErrorSeverity.MEDIUM,
                recovery_strategy=RecoveryStrategy.FALLBACK,
                user_message="网络连接异常,正在尝试备用方案...",
                max_retries=2,
            )
        )
    def add_error_pattern(self, pattern: ErrorPattern):
        """Append an error-handling pattern (earlier patterns win on overlap)."""
        self.error_patterns.append(pattern)
    def get_circuit_breaker(self, key: str) -> CircuitBreaker:
        """Get or lazily create the circuit breaker registered under *key*."""
        if key not in self.circuit_breakers:
            self.circuit_breakers[key] = CircuitBreaker()
        return self.circuit_breakers[key]
    def handle_error(
        self, error: Exception, context: Dict[str, Any] = None, operation: str = "unknown"
    ) -> Dict[str, Any]:
        """Record *error*, match it against patterns, and return handling info."""
        error_id = str(uuid.uuid4())[:8]
        context = context or {}
        # Record the error with full context.
        error_info = {
            "error_id": error_id,
            "error_type": type(error).__name__,
            "error_message": str(error),
            "operation": operation,
            "context": context,
            "timestamp": datetime.now(timezone.utc),
            "traceback": traceback.format_exc(),
        }
        self.last_errors.append(error_info)
        # Keep only the most recent 100 errors.
        if len(self.last_errors) > 100:
            self.last_errors.pop(0)
        # Update error statistics.
        error_key = f"{type(error).__name__}:{operation}"
        self.error_stats[error_key] = self.error_stats.get(error_key, 0) + 1
        # Find a matching handling pattern.
        pattern = self._find_matching_pattern(error)
        if pattern:
            logger.error(f"错误处理 [{error_id}]: {pattern.user_message}")
            return {
                "error_id": error_id,
                "severity": pattern.severity,
                "recovery_strategy": pattern.recovery_strategy,
                "user_message": pattern.user_message,
                "auto_recover": pattern.auto_recover,
                "pattern": pattern,
            }
        else:
            # Unmatched error: use default handling.
            logger.error(f"未知错误 [{error_id}]: {str(error)}")
            return {
                "error_id": error_id,
                "severity": ErrorSeverity.MEDIUM,
                "recovery_strategy": RecoveryStrategy.USER_NOTIFY,
                "user_message": "系统遇到未知错误,请稍后重试或联系管理员",
                "auto_recover": False,
                "pattern": None,
            }
    def _find_matching_pattern(self, error: Exception) -> Optional[ErrorPattern]:
        """Return the first registered pattern whose types match *error*."""
        for pattern in self.error_patterns:
            if any(isinstance(error, error_type) for error_type in pattern.error_types):
                return pattern
        return None
    def get_error_stats(self) -> Dict[str, Any]:
        """Return aggregate error statistics and circuit-breaker states."""
        return {
            "total_errors": sum(self.error_stats.values()),
            "error_breakdown": self.error_stats.copy(),
            "recent_errors": len(self.last_errors),
            "circuit_breakers": {
                key: {"state": cb.state, "failure_count": cb.failure_count}
                for key, cb in self.circuit_breakers.items()
            },
        }
# Global resilience-manager instance.
resilience_manager = ResilienceManager()
def resilient_operation(
    operation_name: str = None,
    max_retries: int = None,
    retry_delay: float = None,
    circuit_breaker_key: str = None,
    context: Dict[str, Any] = None,
):
    """
    Resilient-operation decorator.
    Automatically handles retries, circuit breaking and error recording.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            op_name = operation_name or f"{func.__module__}.{func.__name__}"
            retries = max_retries or 3
            delay = retry_delay or 1.0
            last_error = None
            for attempt in range(retries + 1):
                try:
                    # Route through the circuit breaker when a key was supplied.
                    # NOTE(review): cb.call runs the callable synchronously; for a
                    # coroutine function it only observes coroutine creation, not
                    # failures raised while awaiting — confirm intended.
                    if circuit_breaker_key:
                        cb = resilience_manager.get_circuit_breaker(circuit_breaker_key)
                        if asyncio.iscoroutinefunction(func):
                            return await cb.call(func, *args, **kwargs)
                        else:
                            return cb.call(func, *args, **kwargs)
                    else:
                        if asyncio.iscoroutinefunction(func):
                            return await func(*args, **kwargs)
                        else:
                            return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    # Classify and record the error.
                    error_result = resilience_manager.handle_error(
                        error=e,
                        context={**(context or {}), "attempt": attempt + 1, "max_retries": retries},
                        operation=op_name,
                    )
                    # On the last attempt, or when auto-recovery is off, re-raise
                    # as a user-facing ProxyException.
                    if attempt == retries or not error_result.get("auto_recover", True):
                        raise ProxyException(
                            status_code=500,
                            error_type="system_error",
                            message=error_result["user_message"],
                            details={
                                "error_id": error_result["error_id"],
                                "original_error": str(e),
                            },
                        )
                    # Wait before the next retry.
                    if attempt < retries:
                        await asyncio.sleep(delay * (attempt + 1))  # linear backoff (delay grows by `delay` each attempt)
            # Should be unreachable; kept as a safety net.
            raise last_error
        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # For sync functions, run the async wrapper to completion.
            # NOTE(review): asyncio.run fails when already inside a running
            # event loop; sync callers are assumed to be outside any loop.
            return asyncio.run(async_wrapper(*args, **kwargs))
        # Return the wrapper matching the wrapped function's type.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    return decorator
@asynccontextmanager
async def safe_operation(operation_name: str, context: Dict[str, Any] = None):
    """
    Safe-operation async context manager.

    Wraps a code region; any exception is routed through the global
    ``resilience_manager``. HIGH/CRITICAL errors are re-raised as a
    ``ProxyException`` with a user-friendly message; lower-severity errors
    are logged as warnings and swallowed so the caller continues.

    Args:
        operation_name: Logical name recorded with the error.
        context: Optional extra key/value pairs for the error context.

    Raises:
        ProxyException: For HIGH or CRITICAL severity errors (status 500).
    """
    try:
        yield
    except Exception as e:
        # Classify and record the failure centrally.
        error_result = resilience_manager.handle_error(
            error=e, context=context or {}, operation=operation_name
        )
        # Decide whether to re-raise based on the classified severity.
        if error_result["severity"] in [ErrorSeverity.HIGH, ErrorSeverity.CRITICAL]:
            raise ProxyException(
                status_code=500,
                error_type="system_error",
                message=error_result["user_message"],
                details={"error_id": error_result["error_id"]},
            )
        else:
            # Low severity: log a warning but do not interrupt the operation.
            logger.warning(f"操作警告 [{error_result['error_id']}]: {error_result['user_message']}")
def graceful_degradation(fallback_func: Callable = None, fallback_value: Any = None):
    """
    Graceful-degradation decorator.

    When the wrapped function raises, the fallback callable (sync or async)
    is invoked with the same arguments. If the fallback also fails, the
    ORIGINAL error is re-raised. With no fallback callable, ``fallback_value``
    is returned instead.

    Args:
        fallback_func: Optional backup callable (sync or async).
        fallback_value: Value returned when no fallback_func is provided.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                return func(*args, **kwargs)
            except Exception as e:
                logger.warning(f"主要功能失败,启用降级模式: {func.__name__}")
                if fallback_func:
                    try:
                        if asyncio.iscoroutinefunction(fallback_func):
                            return await fallback_func(*args, **kwargs)
                        return fallback_func(*args, **kwargs)
                    except Exception:
                        logger.exception(f"降级方案也失败了: {fallback_func.__name__}")
                        raise e  # surface the original error, not the fallback's
                return fallback_value

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Fix: the previous bare lambda dropped the wrapped function's
            # metadata (__name__, __doc__); functools.wraps restores it.
            # NOTE(review): asyncio.run() requires no running event loop.
            return asyncio.run(async_wrapper(*args, **kwargs))

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator
# Public interface exported by this module.
__all__ = [
    "resilience_manager",
    "resilient_operation",
    "safe_operation",
    "graceful_degradation",
    "ErrorSeverity",
    "RecoveryStrategy",
    "ErrorPattern",
]

181
src/core/validators.py Normal file
View File

@@ -0,0 +1,181 @@
"""
输入验证器
包含密码复杂度验证和其他输入验证
"""
import re
from typing import List, Optional
class PasswordValidator:
    """Password complexity validator.

    Policy (deliberately relaxed): only length is enforced, plus a denylist
    of very common passwords. Character-class requirements (upper/lower/
    digit/special) were intentionally dropped.
    """

    MIN_LENGTH = 6  # relaxed to 6 characters
    MAX_LENGTH = 128

    # Common weak passwords, pre-lowered once for case-insensitive lookup
    # (the original re-lowered the whole list on every call).
    _WEAK_PASSWORDS = frozenset(
        p.lower()
        for p in (
            "password123",
            "admin123",
            "12345678",
            "qwerty123",
            "password@123",
            "admin@123",
            "Password123!",
            "Admin123!",
        )
    )

    @classmethod
    def validate(cls, password: str) -> tuple[bool, Optional[str]]:
        """
        Validate password complexity (length bounds + weak-password denylist).

        Args:
            password: The password to check.

        Returns:
            (passed, error_message) — error_message is None on success.
        """
        if not password:
            return False, "密码不能为空"
        if len(password) < cls.MIN_LENGTH:
            return False, f"密码长度至少为{cls.MIN_LENGTH}个字符"
        if len(password) > cls.MAX_LENGTH:
            return False, f"密码长度不能超过{cls.MAX_LENGTH}个字符"
        # Length is the only complexity requirement, but reject well-known
        # weak passwords regardless of length (case-insensitive).
        if password.lower() in cls._WEAK_PASSWORDS:
            return False, "密码过于简单,请使用更复杂的密码"
        return True, None

    @classmethod
    def get_password_strength(cls, password: str) -> str:
        """
        Rate password strength.

        Args:
            password: The password to rate.

        Returns:
            One of: "无效" (invalid), "弱" (weak), "中" (medium),
            "强" (strong), "非常强" (very strong).
        """
        if not password:
            return "无效"
        score = 0
        # Length contribution.
        if len(password) >= 8:
            score += 1
        if len(password) >= 12:
            score += 1
        if len(password) >= 16:
            score += 1
        # Character-class contribution.
        if re.search(r"[a-z]", password):
            score += 1
        if re.search(r"[A-Z]", password):
            score += 1
        if re.search(r"\d", password):
            score += 1
        if re.search(r'[!@#$%^&*()_+\-=\[\]{};:\'",.<>?/\\|`~]', password):
            score += 2
        # Extra credit for any other non-alphanumeric character.
        if re.search(r"[^\w\s]", password):
            score += 1
        # Fix: the three lower-tier labels were garbled into empty strings;
        # restored from the documented scale 弱 / 中 / 强 / 非常强.
        if score < 3:
            return "弱"
        elif score < 5:
            return "中"
        elif score < 7:
            return "强"
        else:
            return "非常强"
class EmailValidator:
    """Validates e-mail address format and length."""

    EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")

    @classmethod
    def validate(cls, email: str) -> tuple[bool, Optional[str]]:
        """
        Check that *email* is non-empty, at most 255 characters long, and
        matches the basic address pattern.

        Args:
            email: The e-mail address to check.

        Returns:
            (passed, error_message) — error_message is None on success.
        """
        error: Optional[str] = None
        if not email:
            error = "邮箱不能为空"
        elif len(email) > 255:
            error = "邮箱长度不能超过255个字符"
        elif cls.EMAIL_REGEX.match(email) is None:
            error = "邮箱格式不正确"
        return error is None, error
class UsernameValidator:
    """Validates usernames: length, character set, and reserved names."""

    MIN_LENGTH = 3
    MAX_LENGTH = 30
    USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

    @classmethod
    def validate(cls, username: str) -> tuple[bool, Optional[str]]:
        """
        Check that *username* is 3-30 characters of letters, digits,
        underscores or hyphens, and not a reserved system name.

        Args:
            username: The username to check.

        Returns:
            (passed, error_message) — error_message is None on success.
        """
        if not username:
            return False, "用户名不能为空"

        length = len(username)
        if length < cls.MIN_LENGTH:
            return False, f"用户名长度至少为{cls.MIN_LENGTH}个字符"
        if length > cls.MAX_LENGTH:
            return False, f"用户名长度不能超过{cls.MAX_LENGTH}个字符"
        if cls.USERNAME_REGEX.match(username) is None:
            return False, "用户名只能包含字母、数字、下划线和连字符"

        # Reserved system names are rejected case-insensitively.
        if username.lower() in {
            "admin",
            "root",
            "system",
            "api",
            "test",
            "demo",
            "user",
            "guest",
            "bot",
            "webhook",
            "support",
        }:
            return False, "该用户名为系统保留用户名"

        return True, None