mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
Initial commit
This commit is contained in:
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
272
src/core/api_format_metadata.py
Normal file
272
src/core/api_format_metadata.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
Central registry of API-format metadata, so that adding a new format does not
require editing constants scattered across the code base.

This module works together with the FormatProtocol system in src/formats/:
- api_format_metadata: defines per-format metadata (aliases, default paths)
- src/formats/: defines per-format protocol implementations (parsing, conversion, validation)

Usage:
    # Resolve a format alias
    from src.core.api_format_metadata import resolve_api_format
    api_format = resolve_api_format("claude")  # -> APIFormat.CLAUDE

    # Look up the metadata of a format
    from src.core.api_format_metadata import get_api_format_definition
    definition = get_api_format_definition(APIFormat.CLAUDE)  # -> ApiFormatDefinition
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
|
||||
|
||||
from .enums import APIFormat
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ApiFormatDefinition:
    """
    Immutable description of the generic metadata of one API format.

    Attributes:
        api_format: The APIFormat member this definition describes.
        aliases: Provider aliases / shorthand names that resolve to this format.
        default_path: Default upstream request path (e.g. /v1/messages).
        path_prefix: Local path prefix (e.g. /claude, /openai); empty means none.
        auth_header: Name of the authentication header (e.g. "x-api-key").
        auth_type: "header" sends the raw value, "bearer" adds a Bearer prefix.
    """

    api_format: APIFormat
    aliases: Sequence[str] = field(default_factory=tuple)
    default_path: str = "/"  # default upstream request path
    path_prefix: str = ""  # local path prefix; empty means no prefix
    auth_header: str = "Authorization"
    auth_type: str = "bearer"  # "bearer" or "header"

    def iter_aliases(self) -> Iterable[str]:
        """
        Yield each normalized alias of this format exactly once.

        The enum value itself is always considered first, followed by the
        configured aliases in declaration order. Aliases that normalize to an
        empty string, or to a name that was already yielded, are skipped.
        """
        seen = set()
        for raw in (self.api_format.value, *self.aliases):
            normalized = normalize_alias_value(raw)
            # The enum value often normalizes to one of the explicit aliases
            # (e.g. "CLAUDE" -> "claude"), so de-duplicate while keeping order.
            if normalized and normalized not in seen:
                seen.add(normalized)
                yield normalized
|
||||
|
||||
|
||||
# Built from a flat sequence and keyed by each definition's own ``api_format``,
# so the dictionary key and the definition can never drift apart.
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
    definition.api_format: definition
    for definition in (
        ApiFormatDefinition(
            api_format=APIFormat.CLAUDE,
            aliases=("claude", "anthropic", "claude_compatible"),
            default_path="/v1/messages",
            path_prefix="",  # local path prefix; configurable, e.g. "/claude"
            auth_header="x-api-key",
            auth_type="header",
        ),
        ApiFormatDefinition(
            api_format=APIFormat.CLAUDE_CLI,
            aliases=("claude_cli", "claude-cli"),
            default_path="/v1/messages",
            path_prefix="",  # shares the CLAUDE entry point; told apart by header
            auth_header="authorization",
            auth_type="bearer",
        ),
        ApiFormatDefinition(
            api_format=APIFormat.OPENAI,
            aliases=(
                "openai",
                "deepseek",
                "grok",
                "moonshot",
                "zhipu",
                "qwen",
                "baichuan",
                "minimax",
                "openai_compatible",
            ),
            default_path="/v1/chat/completions",
            path_prefix="",  # local path prefix; configurable, e.g. "/openai"
            auth_header="Authorization",
            auth_type="bearer",
        ),
        ApiFormatDefinition(
            api_format=APIFormat.OPENAI_CLI,
            aliases=("openai_cli", "responses"),
            default_path="/responses",
            path_prefix="",
            auth_header="Authorization",
            auth_type="bearer",
        ),
        ApiFormatDefinition(
            api_format=APIFormat.GEMINI,
            aliases=("gemini", "google", "vertex"),
            default_path="/v1beta/models/{model}:{action}",
            path_prefix="",  # local path prefix; configurable, e.g. "/gemini"
            auth_header="x-goog-api-key",
            auth_type="header",
        ),
        ApiFormatDefinition(
            api_format=APIFormat.GEMINI_CLI,
            aliases=("gemini_cli", "gemini-cli"),
            default_path="/v1beta/models/{model}:{action}",
            path_prefix="",  # shares the GEMINI entry point
            auth_header="x-goog-api-key",
            auth_type="header",
        ),
    )
}

# Only a read-only view is exposed publicly, so callers cannot mutate the registry.
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
|
||||
|
||||
|
||||
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
    """
    Return the registered definition for ``api_format``.

    Raises:
        KeyError: If no definition is registered for the given format.
    """
    definition = API_FORMAT_DEFINITIONS[api_format]
    return definition
|
||||
|
||||
|
||||
def list_api_format_definitions() -> List[ApiFormatDefinition]:
    """Return a new list (shallow copy) of all registered definitions, for iteration."""
    return [*API_FORMAT_DEFINITIONS.values()]
|
||||
|
||||
|
||||
def build_alias_lookup() -> Dict[str, APIFormat]:
    """
    Build the alias -> APIFormat lookup table.

    A fresh dict is returned on every call, so callers never share mutable
    state. When two definitions claim the same alias, the one registered
    first wins.
    """
    table: Dict[str, APIFormat] = {}
    for entry in API_FORMAT_DEFINITIONS.values():
        target = entry.api_format
        for name in entry.iter_aliases():
            # Keep the first owner of an alias; later claims are ignored.
            if name not in table:
                table[name] = target
    return table
|
||||
|
||||
|
||||
def get_default_path(api_format: APIFormat) -> str:
    """
    Return the default upstream request path of the given format.

    Falls back to "/" when the format has no registered definition.
    Per the original author's note, the path can be overridden through
    Endpoint.custom_path (not visible in this module).
    """
    found = API_FORMAT_DEFINITIONS.get(api_format)
    if not found:
        return "/"
    return found.default_path
|
||||
|
||||
|
||||
def get_local_path(api_format: APIFormat) -> str:
    """
    Return the local entry path of the given format.

    The local entry path is ``path_prefix + default_path``; for example
    path_prefix="/openai" and default_path="/v1/chat/completions" give
    "/openai/v1/chat/completions". Unregistered formats yield "/".
    """
    found = API_FORMAT_DEFINITIONS.get(api_format)
    if not found:
        return "/"
    # A falsy prefix (empty string or None) contributes nothing.
    return (found.path_prefix or "") + found.default_path
|
||||
|
||||
|
||||
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
    """
    Return the authentication configuration of the given format.

    Returns:
        An ``(auth_header, auth_type)`` tuple, where ``auth_header`` is the
        header name and ``auth_type`` is "bearer" or "header". Unregistered
        formats fall back to ``("Authorization", "bearer")``.
    """
    found = API_FORMAT_DEFINITIONS.get(api_format)
    if not found:
        return "Authorization", "bearer"
    return found.auth_header, found.auth_type
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _alias_lookup_cache() -> Dict[str, APIFormat]:
    """Memoize the alias -> APIFormat table so it is not rebuilt on every lookup."""
    table = build_alias_lookup()
    return table
|
||||
|
||||
|
||||
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
    """Look up an APIFormat by alias; return None when the alias is empty or unknown."""
    key = normalize_alias_value(value) if value else ""
    if not key:
        return None
    return _alias_lookup_cache().get(key)
|
||||
|
||||
|
||||
def resolve_api_format(
    value: Union[str, APIFormat, None],
    default: Optional[APIFormat] = None,
) -> Optional[APIFormat]:
    """
    Resolve an arbitrary string or enum value to an APIFormat.

    Args:
        value: An APIFormat member, or any string / alias.
        default: Value returned when ``value`` cannot be resolved.

    Returns:
        The resolved APIFormat, or ``default`` if resolution fails.
    """
    if isinstance(value, APIFormat):
        return value
    if not isinstance(value, str):
        return default

    text = value.strip()
    if not text:
        return default

    # Exact (case-insensitive) enum member names take priority over aliases.
    member = APIFormat.__members__.get(text.upper())
    if member is not None:
        return member

    return resolve_api_format_alias(text) or default
|
||||
|
||||
|
||||
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
    """
    Register (or replace) an API format definition at runtime.

    Args:
        definition: The definition to register.
        override: Whether an existing definition for the same format may be replaced.

    Raises:
        ValueError: If the format is already registered and ``override`` is False.
    """
    target = definition.api_format
    if not override and _DEFINITIONS.get(target):
        raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")

    _DEFINITIONS[target] = definition
    # The cached alias table is now stale and must be rebuilt on next use.
    _refresh_metadata_cache()
|
||||
|
||||
|
||||
def _refresh_metadata_cache():
    """Drop the memoized alias table; called after the registry changes."""
    clear = _alias_lookup_cache.cache_clear
    clear()
|
||||
|
||||
|
||||
def normalize_alias_value(value: str) -> str:
    """
    Normalize an alias to its canonical form.

    Surrounding whitespace is stripped, the text is lower-cased, every run of
    characters outside ``a-z0-9`` becomes a single underscore, and leading /
    trailing underscores are removed. ``None`` normalizes to an empty string.
    """
    if value is None:
        return ""
    lowered = value.strip().lower()
    # One substitution both replaces and collapses runs of separators.
    collapsed = re.sub(r"[^a-z0-9]+", "_", lowered)
    return collapsed.strip("_")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 格式判断工具
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def is_cli_api_format(api_format: APIFormat) -> bool:
    """
    Tell whether the given format is a CLI pass-through format.

    Args:
        api_format: An APIFormat member.

    Returns:
        True if the format is a CLI format, as decided by
        ``src.api.handlers.base.parsers.is_cli_format``.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import - confirm).
    from src.api.handlers.base.parsers import is_cli_format

    format_value = api_format.value
    return is_cli_format(format_value)
|
||||
115
src/core/batch_committer.py
Normal file
115
src/core/batch_committer.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
批量提交器 - 减少数据库 commit 次数,提升并发能力
|
||||
|
||||
核心思想:
|
||||
- 非关键数据(监控、统计)不立即 commit
|
||||
- 在后台定期批量 commit
|
||||
- 关键数据(计费)仍然立即 commit
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Set
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class BatchCommitter:
    """
    Batch commit manager.

    Sessions flagged through ``mark_dirty`` are committed together by a
    background task every ``interval_seconds`` seconds, instead of each
    caller committing immediately.
    """

    def __init__(self, interval_seconds: float = 1.0):
        """
        Args:
            interval_seconds: Interval between batch commits, in seconds.
        """
        self.interval_seconds = interval_seconds
        self._pending_sessions: Set[Session] = set()
        self._lock = asyncio.Lock()
        self._task = None  # background asyncio task, None while stopped

    async def start(self):
        """Start the background batch-commit task (no-op if already started)."""
        if self._task is None:
            self._task = asyncio.create_task(self._batch_commit_loop())
            logger.info(f"批量提交器已启动,间隔: {self.interval_seconds}s")

    async def stop(self):
        """Cancel the background task and wait for it to finish."""
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                # Expected: the loop re-raises the cancellation after its final flush.
                pass
            self._task = None
            logger.info("批量提交器已停止")

    def mark_dirty(self, session: Session):
        """Mark a session as having changes that still need to be committed."""
        self._pending_sessions.add(session)

    async def _batch_commit_loop(self):
        """Background loop: sleep for the interval, then commit everything pending."""
        while True:
            try:
                await asyncio.sleep(self.interval_seconds)
                await self._commit_all()
            except asyncio.CancelledError:
                # Flush whatever is still pending before shutting down.
                await self._commit_all()
                raise
            except Exception as e:
                logger.error(f"批量提交出错: {e}")

    async def _commit_all(self):
        """
        Commit every pending session.

        The pending set is snapshotted and cleared under the lock; the commits
        themselves run outside it. A session whose commit fails is rolled back
        and counted as failed; it is not retried.
        """
        async with self._lock:
            if not self._pending_sessions:
                return
            sessions_to_commit = list(self._pending_sessions)
            self._pending_sessions.clear()

        committed = 0
        failed = 0

        for session in sessions_to_commit:
            try:
                # NOTE: commit() is called synchronously on the event loop.
                session.commit()
            except Exception as e:
                failed += 1
                logger.error(f"提交 Session 失败: {e}")
                try:
                    session.rollback()
                except Exception as rollback_error:
                    # Rollback is best-effort, but the failure must not vanish silently.
                    logger.error(f"回滚 Session 失败: {rollback_error}")
            else:
                committed += 1

        if committed > 0:
            logger.debug(f"批量提交完成: {committed} 个 Session")
        if failed > 0:
            logger.warning(f"批量提交失败: {failed} 个 Session")
|
||||
|
||||
|
||||
# Process-wide singleton; stays None until get_batch_committer() is first called.
_batch_committer: "BatchCommitter | None" = None


def get_batch_committer() -> BatchCommitter:
    """Return the global batch committer, creating it (1.0 s interval) on first use."""
    global _batch_committer
    if _batch_committer is None:
        _batch_committer = BatchCommitter(interval_seconds=1.0)
    return _batch_committer
|
||||
|
||||
|
||||
async def init_batch_committer():
    """Fetch the global batch committer and start its background task."""
    await get_batch_committer().start()
|
||||
|
||||
|
||||
async def shutdown_batch_committer():
    """Fetch the global batch committer and stop its background task."""
    await get_batch_committer().stop()
|
||||
174
src/core/cache_service.py
Normal file
174
src/core/cache_service.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
缓存服务 - 统一的缓存抽象层
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.clients.redis_client import get_redis_client
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class CacheService:
    """
    Cache service backed by the client returned from ``get_redis_client``.

    Every operation is best-effort: when no client is available or the
    backend raises, the method logs a warning and returns a neutral value
    (``None`` / ``False``) instead of propagating the error.
    """

    @staticmethod
    async def get(key: str) -> Optional[Any]:
        """
        Read a value from the cache.

        Args:
            key: Cache key.

        Returns:
            The cached value, JSON-decoded when possible and returned raw
            otherwise; ``None`` if the key does not exist or the read fails.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return None

            value = await redis.get(key)
            # Only a missing key is a miss; an empty string is a valid cached value.
            if value is None:
                return None

            try:
                return json.loads(value)
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError as well as the
                # UnicodeDecodeError raised for non-UTF-8 bytes payloads.
                return value

        except Exception as e:
            logger.warning(f"缓存读取失败: {key} - {e}")
            return None

    @staticmethod
    async def set(key: str, value: Any, ttl_seconds: int = 60) -> bool:
        """
        Write a value to the cache.

        Args:
            key: Cache key.
            value: Value to cache. dicts and lists are JSON-encoded, str and
                bytes are stored as-is, anything else is stored as ``str(value)``.
            ttl_seconds: Expiry in seconds (default 60).

        Returns:
            True if the value was written, False otherwise.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False

            if isinstance(value, (dict, list)):
                value = json.dumps(value)
            elif not isinstance(value, (str, bytes)):
                value = str(value)

            await redis.setex(key, ttl_seconds, value)
            return True

        except Exception as e:
            logger.warning(f"缓存写入失败: {key} - {e}")
            return False

    @staticmethod
    async def delete(key: str) -> bool:
        """
        Delete a cache entry.

        Args:
            key: Cache key.

        Returns:
            True if the delete command was issued without error, False otherwise.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False

            await redis.delete(key)
            return True

        except Exception as e:
            logger.warning(f"缓存删除失败: {key} - {e}")
            return False

    @staticmethod
    async def exists(key: str) -> bool:
        """
        Check whether a cache entry exists.

        Args:
            key: Cache key.

        Returns:
            True if the key exists; False if it does not or the check fails.
        """
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return False

            return await redis.exists(key) > 0

        except Exception as e:
            logger.warning(f"缓存检查失败: {key} - {e}")
            return False
|
||||
|
||||
|
||||
# 缓存键前缀
|
||||
class CacheKeys:
    """Cache key templates and the helpers that fill them in."""

    # User cache (original author's note: TTL 60 seconds)
    USER_BY_ID = "user:id:{user_id}"
    USER_BY_EMAIL = "user:email:{email}"

    # API key cache (original author's note: TTL 30 seconds)
    APIKEY_HASH = "apikey:hash:{key_hash}"
    APIKEY_AUTH = "apikey:auth:{key_hash}"  # cached authentication result

    # Provider configuration cache (original author's note: TTL 300 seconds)
    PROVIDER_BY_ID = "provider:id:{provider_id}"
    ENDPOINT_BY_ID = "endpoint:id:{endpoint_id}"
    API_KEY_BY_ID = "api_key:id:{api_key_id}"

    @staticmethod
    def user_by_id(user_id: str) -> str:
        """Build the cache key for a user looked up by ID."""
        template = CacheKeys.USER_BY_ID
        return template.format(user_id=user_id)

    @staticmethod
    def user_by_email(email: str) -> str:
        """Build the cache key for a user looked up by email."""
        template = CacheKeys.USER_BY_EMAIL
        return template.format(email=email)

    @staticmethod
    def apikey_hash(key_hash: str) -> str:
        """Build the cache key for an API key looked up by hash."""
        template = CacheKeys.APIKEY_HASH
        return template.format(key_hash=key_hash)

    @staticmethod
    def apikey_auth(key_hash: str) -> str:
        """Build the cache key for an API key's authentication result."""
        template = CacheKeys.APIKEY_AUTH
        return template.format(key_hash=key_hash)

    @staticmethod
    def provider_by_id(provider_id: str) -> str:
        """Build the cache key for a provider looked up by ID."""
        template = CacheKeys.PROVIDER_BY_ID
        return template.format(provider_id=provider_id)

    @staticmethod
    def endpoint_by_id(endpoint_id: str) -> str:
        """Build the cache key for an endpoint looked up by ID."""
        template = CacheKeys.ENDPOINT_BY_ID
        return template.format(endpoint_id=endpoint_id)

    @staticmethod
    def api_key_by_id(api_key_id: str) -> str:
        """Build the cache key for an API key looked up by ID."""
        template = CacheKeys.API_KEY_BY_ID
        return template.format(api_key_id=api_key_id)
|
||||
133
src/core/cache_utils.py
Normal file
133
src/core/cache_utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
缓存工具类
|
||||
|
||||
提供同步缓存接口,用于不适合使用异步缓存的场景
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class SyncLRUCache:
    """
    Synchronous LRU cache with per-entry TTL, guarded by a re-entrant lock.

    Intended for call sites that need synchronous access. Expiry deadlines
    are measured with ``time.monotonic()``, so adjustments of the system
    wall clock neither expire entries early nor keep them alive too long.
    """

    def __init__(self, max_size: int = 1000, ttl: int = 300) -> None:
        """
        Initialize the cache.

        Args:
            max_size: Maximum number of entries kept.
            ttl: Default time-to-live of an entry, in seconds.
        """
        self._cache: OrderedDict = OrderedDict()
        self._expiry: Dict[Any, float] = {}  # key -> monotonic deadline
        self.max_size = max_size
        self.ttl = ttl
        self._lock = threading.RLock()

    def _is_expired(self, key: Any) -> bool:
        """Tell whether ``key`` is past its deadline (caller must hold the lock)."""
        deadline = self._expiry.get(key)
        return deadline is not None and time.monotonic() > deadline

    def _delete_key(self, key: Any) -> None:
        """Remove ``key`` and its deadline if present (caller must hold the lock)."""
        self._cache.pop(key, None)
        self._expiry.pop(key, None)

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` if it is missing or expired."""
        with self._lock:
            if key not in self._cache:
                return default

            if self._is_expired(key):
                self._delete_key(key)
                return default

            # A successful read makes the entry the most recently used one.
            self._cache.move_to_end(key)
            return self._cache[key]

    def set(self, key: Any, value: Any, ttl: Optional[int] = None) -> None:
        """
        Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Time-to-live in seconds; ``None`` uses the cache default.
        """
        with self._lock:
            if ttl is None:
                ttl = self.ttl

            if key in self._cache:
                self._cache.move_to_end(key)

            self._cache[key] = value
            self._expiry[key] = time.monotonic() + ttl

            # Evict least recently used entries until the size limit holds.
            while len(self._cache) > self.max_size:
                oldest = next(iter(self._cache))
                self._delete_key(oldest)

    def delete(self, key: Any) -> None:
        """Remove ``key`` from the cache; a missing key is ignored."""
        with self._lock:
            self._delete_key(key)

    def clear(self) -> None:
        """Remove every entry."""
        with self._lock:
            self._cache.clear()
            self._expiry.clear()

    def __contains__(self, key: Any) -> bool:
        """Tell whether ``key`` is present and not expired (expired keys are purged)."""
        with self._lock:
            if key not in self._cache:
                return False
            if self._is_expired(key):
                self._delete_key(key)
                return False
            return True

    def __getitem__(self, key: Any) -> Any:
        """
        Return the value for ``key`` via index syntax.

        Raises:
            KeyError: If the key is missing or expired.
        """
        with self._lock:
            if key not in self._cache:
                raise KeyError(key)

            if self._is_expired(key):
                self._delete_key(key)
                raise KeyError(key)

            self._cache.move_to_end(key)
            return self._cache[key]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key`` with the default TTL via index syntax."""
        self.set(key, value)

    def __delitem__(self, key: Any) -> None:
        """Remove ``key`` via ``del``; a missing key is ignored (no KeyError)."""
        self.delete(key)

    def keys(self):
        """Return a list of all keys that have not expired (nothing is purged)."""
        with self._lock:
            now = time.monotonic()
            return [
                k for k in self._cache if k not in self._expiry or now <= self._expiry[k]
            ]

    def get_stats(self) -> Dict[str, Any]:
        """
        Return cache statistics.

        ``size`` counts stored entries, including expired ones that have not
        been purged yet.
        """
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "ttl": self.ttl,
            }
|
||||
168
src/core/context.py
Normal file
168
src/core/context.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
统一的请求上下文
|
||||
|
||||
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
|
||||
这确保了数据在各层之间传递时不会丢失。
|
||||
|
||||
使用方式:
|
||||
1. Pipeline 层创建 RequestContext
|
||||
2. 各层通过 context 访问和更新信息
|
||||
3. Adapter 层使用 context 记录 Usage
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
class RequestContext:
    """
    Request context carried through the whole request lifecycle.

    Design:
    1. Created when the request starts, holding everything known at that point.
    2. Provider information is filled in while the request is executed.
    3. Used at the end of the request to record usage.
    """

    # ---- Request identity ----
    request_id: str

    # ---- Authentication ----
    user: Any  # User model
    api_key: Any  # ApiKey model
    db: Any  # Database session

    # ---- Request ----
    api_format: str  # CLAUDE, OPENAI, GEMINI, etc.
    model: str  # model name requested by the user
    is_stream: bool = False

    # ---- Original request ----
    original_headers: Dict[str, str] = field(default_factory=dict)
    original_body: Dict[str, Any] = field(default_factory=dict)

    # ---- Client ----
    client_ip: str = "unknown"
    user_agent: str = ""

    # ---- Timing ----
    start_time: float = field(default_factory=time.time)

    # ---- Provider (filled in after the request is executed) ----
    provider_name: Optional[str] = None
    provider_id: Optional[str] = None
    endpoint_id: Optional[str] = None
    provider_api_key_id: Optional[str] = None

    # ---- Model mapping ----
    resolved_model: Optional[str] = None  # model name after mapping
    original_model: Optional[str] = None  # original model name (used for pricing)

    # ---- Provider request / response headers ----
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_response_headers: Dict[str, str] = field(default_factory=dict)

    # ---- Tracing ----
    attempt_id: Optional[str] = None

    # ---- Capability requirements ----
    # Computed at runtime; per the original author's note the sources are:
    #   1. the user's model_capability_settings
    #   2. the user's ApiKey.force_capabilities
    #   3. the X-Require-Capability request header
    #   4. requirements added dynamically on retry after a failure
    capability_requirements: Dict[str, bool] = field(default_factory=dict)

    @classmethod
    def create(
        cls,
        *,
        db: Any,
        user: Any,
        api_key: Any,
        api_format: str,
        model: str,
        is_stream: bool = False,
        original_headers: Optional[Dict[str, str]] = None,
        original_body: Optional[Dict[str, Any]] = None,
        client_ip: str = "unknown",
        user_agent: str = "",
        request_id: Optional[str] = None,
    ) -> "RequestContext":
        """
        Build a request context.

        A random UUID4 is generated when ``request_id`` is not supplied, and
        missing headers / body default to empty dicts. ``original_model``
        starts out equal to the requested ``model``.
        """
        if not request_id:
            request_id = str(uuid.uuid4())
        headers = original_headers or {}
        body = original_body or {}
        return cls(
            request_id=request_id,
            db=db,
            user=user,
            api_key=api_key,
            api_format=api_format,
            model=model,
            is_stream=is_stream,
            original_headers=headers,
            original_body=body,
            client_ip=client_ip,
            user_agent=user_agent,
            original_model=model,
        )

    def update_provider_info(
        self,
        *,
        provider_name: str,
        provider_id: str,
        endpoint_id: str,
        provider_api_key_id: str,
        resolved_model: Optional[str] = None,
    ) -> None:
        """
        Record which provider served the request.

        ``resolved_model`` is only overwritten when a truthy value is given.
        """
        self.provider_name, self.provider_id = provider_name, provider_id
        self.endpoint_id, self.provider_api_key_id = endpoint_id, provider_api_key_id
        if not resolved_model:
            return
        self.resolved_model = resolved_model

    def update_headers(
        self,
        *,
        request_headers: Optional[Dict[str, str]] = None,
        response_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Replace the stored provider request / response headers; empty or None values are ignored."""
        if request_headers:
            self.provider_request_headers = request_headers
        if response_headers:
            self.provider_response_headers = response_headers

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since ``start_time``, truncated to an int."""
        elapsed_seconds = time.time() - self.start_time
        return int(elapsed_seconds * 1000)

    @property
    def effective_model(self) -> str:
        """The model name in effect: the mapped name if set, else the requested one."""
        if self.resolved_model:
            return self.resolved_model
        return self.model

    @property
    def billing_model(self) -> str:
        """The model name used for billing: the original name if set, else the requested one."""
        if self.original_model:
            return self.original_model
        return self.model

    def to_metadata_dict(self) -> Dict[str, Any]:
        """Export the context as a metadata dict (used when recording usage)."""
        return dict(
            api_format=self.api_format,
            provider=self.provider_name or "unknown",
            model=self.effective_model,
            original_model=self.billing_model,
            provider_id=self.provider_id,
            provider_endpoint_id=self.endpoint_id,
            provider_api_key_id=self.provider_api_key_id,
            provider_request_headers=self.provider_request_headers,
            provider_response_headers=self.provider_response_headers,
            attempt_id=self.attempt_id,
        )
|
||||
166
src/core/crypto.py
Normal file
166
src/core/crypto.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
加密工具模块
|
||||
提供API密钥的加密和解密功能
|
||||
|
||||
安全说明:
|
||||
- 生产环境必须设置独立的 ENCRYPTION_KEY
|
||||
- 加密密钥应独立于 JWT_SECRET_KEY,避免密钥轮换问题
|
||||
- 使用 PBKDF2 派生密钥时会使用应用级 salt
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from ..config import config
|
||||
from ..core.exceptions import DecryptionException
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class CryptoService:
    """
    Encryption service (process-wide singleton).

    Provides symmetric encryption for sensitive data such as provider API
    keys, using Fernet (AES-128-CBC + HMAC-SHA256) for confidentiality and
    integrity.
    """

    _instance = None
    _cipher = None
    _key_source: str = "unknown"  # where the key came from; kept for debugging

    # Application-level salt used for PBKDF2 key derivation.
    # NOTE: changing this value makes data encrypted with a derived key undecryptable.
    APP_SALT = hashlib.sha256(b"aether-v1").digest()[:16]

    def __new__(cls) -> "CryptoService":
        # Create and initialize the single shared instance on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self) -> None:
        """
        Set up the Fernet cipher from ``config.encryption_key``.

        Raises:
            ValueError: If no key is configured in the production environment.
        """
        logger.info("初始化加密服务")

        encryption_key = config.encryption_key

        if not encryption_key:
            if config.environment == "production":
                raise ValueError(
                    "ENCRYPTION_KEY must be set in production! "
                    "Use 'python generate_keys.py' to generate a secure key."
                )
            # Non-production only: fall back to a fixed development key.
            logger.warning("[DEV] 未设置 ENCRYPTION_KEY,使用开发环境默认密钥。")
            encryption_key = "dev-encryption-key-do-not-use-in-production"
            self._key_source = "development_default"
        else:
            self._key_source = "environment_variable"

        key = self._derive_fernet_key(encryption_key)

        self._cipher = Fernet(key)
        logger.info(f"加密服务初始化成功 (key_source={self._key_source})")

    def _derive_fernet_key(self, encryption_key: str) -> bytes:
        """
        Turn a configured key/password into a Fernet-compatible key.

        Args:
            encryption_key: The raw key, as ``str`` (or ``bytes``).

        Returns:
            The key itself when it already is a valid Fernet key, otherwise a
            base64-encoded 32-byte key derived with PBKDF2-HMAC-SHA256.
        """
        key_bytes = (
            encryption_key.encode() if isinstance(encryption_key, str) else encryption_key
        )

        # First try the value as a ready-made Fernet key.
        try:
            Fernet(key_bytes)
            return key_bytes
        except Exception:
            # Deliberate fallback: anything Fernet rejects is treated as a passphrase.
            pass

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=self.APP_SALT,
            iterations=100000,
        )
        # Reuse key_bytes so a bytes key works here too (bytes has no .encode()).
        derived_key = kdf.derive(key_bytes)
        return base64.urlsafe_b64encode(derived_key)

    def encrypt(self, plaintext: str) -> str:
        """
        Encrypt a string.

        Args:
            plaintext: Text to encrypt. Falsy input is returned unchanged.

        Returns:
            The Fernet token, additionally base64-encoded, as text.

        Raises:
            ValueError: If encryption fails.
        """
        if not plaintext:
            return plaintext

        try:
            encrypted = self._cipher.encrypt(plaintext.encode())
            return base64.urlsafe_b64encode(encrypted).decode()
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            raise ValueError("Failed to encrypt data") from e

    def decrypt(self, ciphertext: str, silent: bool = False) -> str:
        """
        Decrypt a string produced by ``encrypt``.

        Args:
            ciphertext: Encrypted text (base64-encoded). Falsy input is returned unchanged.
            silent: When True, a failure is not written to the error log.

        Returns:
            The decrypted plaintext.

        Raises:
            DecryptionException: If decryption fails.
        """
        if not ciphertext:
            return ciphertext

        try:
            encrypted = base64.urlsafe_b64decode(ciphertext.encode())
            decrypted = self._cipher.decrypt(encrypted)
            return decrypted.decode()
        except Exception as e:
            if not silent:
                logger.error(f"Decryption failed: {e}")
            # A dedicated exception type lets upper layers decide by type
            # whether a stack trace should be printed.
            raise DecryptionException(
                message=f"解密失败: {str(e)}。可能原因: ENCRYPTION_KEY 已改变或数据已损坏。解决方案: 请在管理面板重新设置 Provider API Key。",
                details={"original_error": str(e), "key_source": self._key_source},
            ) from e

    def hash_api_key(self, api_key: str) -> str:
        """
        Hash an API key (used for lookups).

        Args:
            api_key: The API key in plaintext.

        Returns:
            The hexadecimal SHA-256 digest of the key.
        """
        return hashlib.sha256(api_key.encode()).hexdigest()


# Global encryption service instance, created when this module is imported.
crypto_service = CryptoService()
|
||||
32
src/core/enums.py
Normal file
32
src/core/enums.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
统一的枚举定义
|
||||
避免重复定义造成的不一致
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class APIFormat(Enum):
    """API format - decides how requests and responses are processed."""

    # Claude API format
    CLAUDE = "CLAUDE"
    # OpenAI API format
    OPENAI = "OPENAI"
    # Claude CLI API format (uses "authorization: Bearer")
    CLAUDE_CLI = "CLAUDE_CLI"
    # OpenAI CLI / Responses API format
    OPENAI_CLI = "OPENAI_CLI"
    # Google Gemini API format
    GEMINI = "GEMINI"
    # Gemini CLI API format
    GEMINI_CLI = "GEMINI_CLI"
|
||||
|
||||
|
||||
class UserRole(Enum):
    """Role of a user account."""

    ADMIN = "admin"
    USER = "user"
|
||||
|
||||
|
||||
class ProviderBillingType(Enum):
    """How a provider is billed."""

    # Monthly quota
    MONTHLY_QUOTA = "monthly_quota"
    # Pay as you go
    PAY_AS_YOU_GO = "pay_as_you_go"
    # Free tier
    FREE_TIER = "free_tier"
|
||||
675
src/core/exceptions.py
Normal file
675
src/core/exceptions.py
Normal file
@@ -0,0 +1,675 @@
|
||||
"""
|
||||
统一的异常处理和错误响应定义
|
||||
|
||||
安全说明:
|
||||
- 生产环境不返回详细错误信息,避免信息泄露
|
||||
- 使用错误 ID 关联日志,便于排查问题
|
||||
- 开发环境可返回详细信息用于调试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..config import config
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
# Pydantic error-message translations: English regex -> Chinese replacement.
#
# Consumers walk this mapping in insertion order and stop at the first
# pattern that matches (case-insensitively), so ORDER MATTERS:
#   * "Value error, ..." comes first so that messages raised by custom
#     validators are used verbatim instead of being partially rewritten by a
#     broader pattern such as "missing" or "field required".
#   * Specific "Input should be ..." patterns come before the generic one.
# Numeric patterns accept an optional sign and decimal part so that bounds
# such as -1 or 0.5 are translated as a whole.
PYDANTIC_ERROR_TRANSLATIONS = {
    # Custom validator errors: use the validator's own message directly
    r"Value error, (.+)": r"\1",
    # String validation
    r"String should have at least (\d+) characters?": r"字符串长度至少需要 \1 个字符",
    r"String should have at most (\d+) characters?": r"字符串长度最多 \1 个字符",
    r"string_too_short": "字符串长度不足",
    r"string_too_long": "字符串长度超出限制",
    # Required fields
    r"Field required": "此字段为必填项",
    r"field required": "此字段为必填项",
    r"missing": "缺少必填字段",
    # Type errors
    r"Input should be a valid string": "输入应为有效的字符串",
    r"Input should be a valid integer": "输入应为有效的整数",
    r"Input should be a valid number": "输入应为有效的数字",
    r"Input should be a valid boolean": "输入应为布尔值",
    r"Input should be a valid email address": "输入应为有效的邮箱地址",
    r"Input should be a valid list": "输入应为有效的列表",
    r"Input should be a valid dictionary": "输入应为有效的字典",
    # Numeric validation (signed integers and decimals)
    r"Input should be greater than (-?\d+(?:\.\d+)?)": r"数值应大于 \1",
    r"Input should be greater than or equal to (-?\d+(?:\.\d+)?)": r"数值应大于或等于 \1",
    r"Input should be less than (-?\d+(?:\.\d+)?)": r"数值应小于 \1",
    r"Input should be less than or equal to (-?\d+(?:\.\d+)?)": r"数值应小于或等于 \1",
    # Enum validation (generic fallback for any other "Input should be ...")
    r"Input should be (.+)": r"输入应为 \1",
    # Miscellaneous
    r"value is not a valid email address": "邮箱地址格式无效",
    r"invalid.*email": "邮箱地址格式无效",
    r"Extra inputs are not permitted": "不允许额外的字段",
}
|
||||
|
||||
# Field-name translations: English field name -> Chinese display label.
FIELD_NAME_TRANSLATIONS = {
    "password": "密码",
    "username": "用户名",
    "email": "邮箱",
    "role": "角色",
    "quota_usd": "配额",
    "name": "名称",
    "title": "标题",
    "content": "内容",
    "ip_address": "IP地址",
    "reason": "原因",
    "ttl": "过期时间",
    "enabled": "启用状态",
    "fixed_limit": "固定限制",
    "old_password": "旧密码",
    "new_password": "新密码",
    "allowed_providers": "允许的提供商",
    "allowed_models": "允许的模型",
    "rate_limit": "速率限制",
    "expire_days": "过期天数",
    "priority": "优先级",
    "type": "类型",
    "is_active": "激活状态",
    "is_pinned": "置顶状态",
    "start_time": "开始时间",
    "end_time": "结束时间",
}
|
||||
|
||||
|
||||
def translate_pydantic_error(error: Dict[str, Any]) -> str:
    """
    Translate a single Pydantic validation error into Chinese.

    Args:
        error: Pydantic error dict; the "loc" and "msg" entries are read.

    Returns:
        "<field>: <message>" when a field name is available, otherwise just
        the message. The message is rewritten using the first matching
        pattern of PYDANTIC_ERROR_TRANSLATIONS (case-insensitive) and is left
        untouched when no pattern matches.
    """
    # Only the first element of "loc" is used as the field name.
    # NOTE(review): for request-level errors the first element may be a
    # location marker such as "body" rather than the field - confirm callers.
    location = error.get("loc", [])
    field_name = str(location[0]) if location else ""
    field_label = FIELD_NAME_TRANSLATIONS.get(field_name, field_name)

    original_msg = error.get("msg", "验证失败")

    # First pattern that substitutes at least once wins.
    result = original_msg
    for regex, template in PYDANTIC_ERROR_TRANSLATIONS.items():
        candidate, hits = re.subn(regex, template, original_msg, flags=re.IGNORECASE)
        if hits:
            result = candidate
            break

    return f"{field_label}: {result}" if field_label else result
|
||||
|
||||
|
||||
def translate_pydantic_errors(errors: List[Dict[str, Any]]) -> str:
    """
    Translate a list of Pydantic validation errors.

    Args:
        errors: List of Pydantic error dicts.

    Returns:
        The translated messages joined with "; ", or a generic
        "request data validation failed" message when ``errors`` is empty.
    """
    if errors:
        return "; ".join(translate_pydantic_error(item) for item in errors)
    return "请求数据验证失败"
|
||||
|
||||
|
||||
|
||||
# The resilience manager is imported lazily to avoid a circular import.
def get_resilience_manager():
    """Return the shared resilience manager, or ``None`` if it cannot be imported."""
    try:
        from ..core.resilience import resilience_manager as manager
    except ImportError:
        return None
    return manager
|
||||
|
||||
|
||||
class ProxyException(HTTPException):
    """Base exception for the proxy service.

    Attributes:
        error_type: Machine-readable error category.
        message: Human-readable message; also passed to HTTPException as ``detail``.
        details: Extra structured context (an empty dict when none is given).
    """

    def __init__(
        self,
        status_code: int,
        error_type: str,
        message: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(status_code=status_code, detail=message)
        self.error_type = error_type
        self.message = message
        self.details = details if details else {}
|
||||
|
||||
|
||||
class ProviderException(ProxyException):
    """Provider-related error (HTTP 503, error type ``provider_error``).

    Any extra keyword arguments are merged into ``details``; ``provider`` is
    only added to ``details`` when ``provider_name`` is truthy.
    """

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
        **kwargs,
    ):
        # Kept on the instance so the metadata can be passed along by callers.
        self.request_metadata = request_metadata

        details = {}
        if provider_name:
            details["provider"] = provider_name
        details.update(kwargs)

        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            error_type="provider_error",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class ProviderNotAvailableException(ProviderException):
    """Provider is not available."""

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        # Plain pass-through; no extra detail fields are accepted here.
        forwarded = {
            "message": message,
            "provider_name": provider_name,
            "request_metadata": request_metadata,
        }
        super().__init__(**forwarded)
|
||||
|
||||
|
||||
class ProviderTimeoutException(ProviderException):
    """Provider request timed out; ``timeout`` (seconds) is recorded in ``details``."""

    def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
        text = f"提供商 '{provider_name}' 请求超时({timeout}秒)"
        super().__init__(
            message=text,
            provider_name=provider_name,
            request_metadata=request_metadata,
            timeout=timeout,
        )
|
||||
|
||||
|
||||
class ProviderAuthException(ProviderException):
    """Provider authentication failed."""

    def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
        text = f"提供商 '{provider_name}' 认证失败,请检查API密钥"
        super().__init__(
            message=text,
            provider_name=provider_name,
            request_metadata=request_metadata,
        )
|
||||
|
||||
|
||||
class ProviderRateLimitException(ProviderException):
    """Provider rate limit hit.

    Attributes:
        response_headers: Upstream response headers (an empty dict when not given).
        retry_after: Retry delay supplied by the caller, if any.
    """

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        request_metadata: Optional[Any] = None,
        response_headers: Optional[Dict[str, str]] = None,
        retry_after: Optional[int] = None,
    ):
        self.response_headers = response_headers if response_headers else {}
        self.retry_after = retry_after
        super().__init__(
            message=message,
            provider_name=provider_name,
            request_metadata=request_metadata,
        )
|
||||
|
||||
|
||||
class QuotaExceededException(ProxyException):
    """Quota exhausted (HTTP 429, error type ``quota_exceeded``)."""

    def __init__(self, quota_type: str = "tokens", remaining: Optional[float] = None):
        # The remaining amount is appended only when the caller supplied one.
        suffix = "" if remaining is None else f"(剩余: {remaining})"
        message = f"{quota_type}配额已用尽" + suffix
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="quota_exceeded",
            message=message,
            details={"quota_type": quota_type, "remaining": remaining},
        )
|
||||
|
||||
|
||||
class RateLimitException(ProxyException):
    """Request rate limit exceeded (HTTP 429, error type ``rate_limit``)."""

    def __init__(self, limit: int, window: str = "minute"):
        text = f"请求过于频繁,限制为每{window} {limit}次"
        context = {"limit": limit, "window": window}
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="rate_limit",
            message=text,
            details=context,
        )
|
||||
|
||||
|
||||
class ConcurrencyLimitError(ProxyException):
    """Concurrency limit reached (HTTP 429, error type ``concurrency_limit``)."""

    def __init__(
        self, message: str, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
    ):
        # Only identifiers that were actually supplied end up in the details.
        candidates = (("endpoint_id", endpoint_id), ("key_id", key_id))
        details = {label: value for label, value in candidates if value}

        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            error_type="concurrency_limit",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class ModelNotSupportedException(ProxyException):
    """Model is not supported (HTTP 400, error type ``model_not_supported``)."""

    def __init__(self, model: str, provider_name: Optional[str] = None):
        if provider_name:
            message = f"提供商 '{provider_name}' 不支持模型 '{model}'"
        else:
            message = f"模型 '{model}' 不受支持"
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="model_not_supported",
            message=message,
            details={"model": model, "provider": provider_name},
        )
|
||||
|
||||
|
||||
class StreamingNotSupportedException(ProxyException):
    """Streaming is not supported (HTTP 400, error type ``streaming_not_supported``)."""

    def __init__(self, model: str, provider_name: Optional[str] = None):
        message = (
            f"模型 '{model}' 在提供商 '{provider_name}' 上不支持流式请求"
            if provider_name
            else f"模型 '{model}' 不支持流式请求"
        )
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="streaming_not_supported",
            message=message,
            details={"model": model, "provider": provider_name},
        )
|
||||
|
||||
|
||||
class InvalidRequestException(ProxyException):
    """Invalid request (HTTP 400, error type ``invalid_request``)."""

    def __init__(self, message: str, field: Optional[str] = None):
        details = {}
        if field:
            details["field"] = field
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            error_type="invalid_request",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class NotFoundException(ProxyException):
    """Resource not found (HTTP 404, error type ``not_found``)."""

    def __init__(self, message: str, resource_type: Optional[str] = None):
        details = {}
        if resource_type:
            details["resource_type"] = resource_type
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            error_type="not_found",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class ForbiddenException(ProxyException):
    """Insufficient permissions (HTTP 403, error type ``forbidden``)."""

    def __init__(self, message: str, required_role: Optional[str] = None):
        details = {}
        if required_role:
            details["required_role"] = required_role
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            error_type="forbidden",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class DecryptionException(ProxyException):
    """Decryption failure - a known configuration problem; no stack trace needed."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        context = details if details else {}
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            error_type="decryption_error",
            message=message,
            details=context,
        )
|
||||
|
||||
|
||||
class JSONParseException(ProviderException):
    """Provider returned a response that could not be parsed as JSON.

    The details carry the original parser error and content type. A response
    body longer than 500 characters is stored as ``response_preview`` (first
    and last 200 characters); a shorter non-empty body is stored whole as
    ``response_content``.
    """

    def __init__(
        self,
        provider_name: str,
        original_error: str,
        response_content: Optional[str] = None,
        content_type: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        extra = {"original_error": original_error, "content_type": content_type}

        if response_content:
            if len(response_content) > 500:
                # Truncate long bodies but keep both ends.
                head, tail = response_content[:200], response_content[-200:]
                extra["response_preview"] = f"{head}...{tail}"
            else:
                extra["response_content"] = response_content

        super().__init__(
            message=f"提供商 '{provider_name}' 返回了无效的JSON响应",
            provider_name=provider_name,
            request_metadata=request_metadata,
            **extra,
        )
|
||||
|
||||
|
||||
class EmptyStreamException(ProviderException):
    """Empty streaming response - upstream answered 200 but sent no data."""

    def __init__(
        self,
        provider_name: str,
        chunk_count: int = 0,
        request_metadata: Optional[Any] = None,
    ):
        text = f"提供商 '{provider_name}' 返回了空的流式响应(status=200 但无数据)"
        super().__init__(
            message=text,
            provider_name=provider_name,
            request_metadata=request_metadata,
            chunk_count=chunk_count,
        )
|
||||
|
||||
|
||||
class EmbeddedErrorException(ProviderException):
    """Error embedded in a response body whose HTTP status looked fine.

    Covers providers (e.g. Gemini) that answer HTTP 200 while the body carries
    an error. Such errors are meant to trigger the retry logic.
    """

    def __init__(
        self,
        provider_name: str,
        error_code: Optional[int] = None,
        error_message: Optional[str] = None,
        error_status: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        # Code and message are appended only when they are truthy.
        parts = [f"提供商 '{provider_name}' 返回了嵌套错误"]
        if error_code:
            parts.append(f" (code={error_code})")
        if error_message:
            parts.append(f": {error_message}")

        super().__init__(
            message="".join(parts),
            provider_name=provider_name,
            request_metadata=request_metadata,
            error_code=error_code,
            error_status=error_status,
        )
        self.error_code, self.error_message, self.error_status = (
            error_code,
            error_message,
            error_status,
        )
|
||||
|
||||
|
||||
class UpstreamClientException(ProxyException):
    """Client error (HTTP 4xx) returned by the upstream provider - do not retry.

    These errors are caused by the user's request itself, so switching to a
    different provider would not help.

    Typical cases:
    - Image processing failed (image too large, unsupported format, ...)
    - Invalid request parameters
    - Message content violates policy
    """

    def __init__(
        self,
        message: str,
        provider_name: Optional[str] = None,
        status_code: int = 400,
        error_type: Optional[str] = None,
        upstream_error: Optional[str] = None,
        request_metadata: Optional[Any] = None,
    ):
        self.upstream_error = upstream_error
        self.request_metadata = request_metadata

        # Only truthy values are recorded; the upstream's own error type is
        # kept separately from this exception's fixed error_type.
        optional_fields = (
            ("provider", provider_name),
            ("upstream_error_type", error_type),
            ("upstream_error", upstream_error),
        )
        details = {key: value for key, value in optional_fields if value}

        super().__init__(
            status_code=status_code,
            error_type="upstream_client_error",
            message=message,
            details=details,
        )
|
||||
|
||||
|
||||
class ErrorResponse:
    """Unified formatter that turns errors into JSON responses."""

    @staticmethod
    def create(
        error_type: str,
        message: str,
        status_code: int = 500,
        details: Optional[Dict[str, Any]] = None,
    ) -> JSONResponse:
        """Build the standard error response and log it.

        The body has the shape ``{"error": {"type", "message"[, "details"]}}``;
        ``details`` is included only when it is non-empty.
        """
        error_body = {"error": {"type": error_type, "message": message}}

        if details:
            error_body["error"]["details"] = details

        # Every error response is logged at ERROR level.
        logger.error(f"Error response: {error_type} - {message}")

        return JSONResponse(status_code=status_code, content=error_body)

    @staticmethod
    def from_exception(e: Exception) -> JSONResponse:
        """
        Build an error response from an exception.

        - ProxyException: its own type, message, status code and details.
        - Other HTTPException: ``http_error`` with the exception's detail.
        - Anything else: HTTP 500 tagged with a short error ID that is also
          written to the log. In development/test environments the response
          additionally carries the exception type, message and traceback; in
          other environments only the error ID and a support hint.
        """
        if isinstance(e, ProxyException):
            return ErrorResponse.create(
                error_type=e.error_type,
                message=e.message,
                status_code=e.status_code,
                details=e.details,
            )

        if isinstance(e, HTTPException):
            return ErrorResponse.create(
                error_type="http_error", message=str(e.detail), status_code=e.status_code
            )

        # Unknown exception: use the error-ID mechanism.
        error_id = str(uuid.uuid4())[:8]  # short ID that users can report
        error_type_name = type(e).__name__
        error_message = str(e)

        # Always log the error together with its ID.
        logger.error(f"[{error_id}] Unexpected error: {error_type_name}: {error_message}")

        is_development = config.environment in ("development", "test", "testing")
        if not is_development:
            # Production: expose the error ID only.
            return ErrorResponse.create(
                error_type="internal_error",
                message="内部服务器错误",
                status_code=500,
                details={
                    "error_id": error_id,
                    "support_info": "请联系管理员并提供此错误 ID",
                },
            )

        # Development: expose full information. The traceback is formatted
        # from the exception that was passed in, so it is correct even when
        # this method is called outside the ``except`` block that caught it.
        trace_lines = "".join(
            traceback.format_exception(type(e), e, e.__traceback__)
        ).split("\n")
        return ErrorResponse.create(
            error_type="internal_error",
            message=f"内部服务器错误: {error_type_name}: {error_message}",
            status_code=500,
            details={
                "error_id": error_id,
                "error_type": error_type_name,
                "error": error_message,
                "traceback": trace_lines,
            },
        )

    @staticmethod
    def _generic_provider_error(provider_name: str, error: Exception) -> JSONResponse:
        """Build the catch-all HTTP 503 ``provider_error`` response."""
        return ErrorResponse.create(
            error_type="provider_error",
            message=f"提供商请求失败: {str(error)}",
            status_code=503,
            details={
                "provider": provider_name,
                "error": str(error),
                "error_type": type(error).__name__,
            },
        )

    @staticmethod
    def provider_error(provider_name: str, error: Exception) -> JSONResponse:
        """Build a provider error response, classifying by exception type first.

        String matching on the error text is used only as a fallback for
        exceptions whose type gives no information. A response is always
        returned.
        """
        if isinstance(error, (asyncio.TimeoutError, httpx.TimeoutException)):
            # NOTE(review): 60 is a fixed value; the timeout that actually
            # applied to the request is not available here.
            return ErrorResponse.from_exception(ProviderTimeoutException(provider_name, 60))

        if isinstance(error, httpx.HTTPStatusError):
            upstream_status = error.response.status_code
            if upstream_status == 401:
                return ErrorResponse.from_exception(ProviderAuthException(provider_name))
            if upstream_status == 429:
                return ErrorResponse.from_exception(
                    ProviderRateLimitException(
                        message=f"提供商 '{provider_name}' 速率限制",
                        provider_name=provider_name,
                    )
                )
            # Every other upstream status must still yield a response;
            # without this return the method would fall through to None.
            return ErrorResponse._generic_provider_error(provider_name, error)

        if isinstance(error, (httpx.ConnectError, httpx.NetworkError)):
            return ErrorResponse.create(
                error_type="provider_connection_error",
                message=f"无法连接到提供商 {provider_name}",
                status_code=503,
                details={"provider": provider_name, "error": "Connection failed"},
            )

        # Fallback: the type was not conclusive, so inspect the error text.
        error_text = str(error)
        lowered = error_text.lower()
        if "auth" in lowered or "401" in error_text:
            return ErrorResponse.from_exception(ProviderAuthException(provider_name))
        if "rate limit" in lowered or "429" in error_text:
            return ErrorResponse.from_exception(
                ProviderRateLimitException(
                    message=f"提供商 '{provider_name}' 速率限制",
                    provider_name=provider_name,
                )
            )

        return ErrorResponse._generic_provider_error(provider_name, error)
|
||||
|
||||
|
||||
class ExceptionHandlers:
    """FastAPI exception handlers."""

    @staticmethod
    async def handle_proxy_exception(request, exc: ProxyException):
        """Handle proxy exceptions."""
        return ErrorResponse.from_exception(exc)

    @staticmethod
    async def handle_http_exception(request, exc: HTTPException):
        """Handle HTTP exceptions."""
        return ErrorResponse.from_exception(exc)

    @staticmethod
    async def handle_generic_exception(request, exc: Exception):
        """Handle any other exception, going through the resilience manager.

        HTTPException instances are delegated to the HTTP handler. Otherwise
        the resilience manager (when importable) classifies the error: a
        "critical" severity maps to HTTP 503, everything else to HTTP 500.
        If the manager is unavailable or fails itself, the basic
        ``ErrorResponse.from_exception`` path is used.
        """
        if isinstance(exc, HTTPException):
            return await ExceptionHandlers.handle_http_exception(request, exc)

        # Request information used as context for the resilience manager.
        request_info = {"method": request.method, "url": str(request.url)}
        if hasattr(request, "client"):
            request_info["client_ip"] = getattr(request.client, "host", "unknown")
        else:
            request_info["client_ip"] = "unknown"
        request_info["user_agent"] = request.headers.get("user-agent", "unknown")

        manager = get_resilience_manager()
        if manager:
            try:
                outcome = manager.handle_error(
                    error=exc,
                    context=request_info,
                    operation=f"{request.method} {request.url.path}",
                )

                # Only "critical" changes the status code; "high" and every
                # other severity both map to 500.
                severity = outcome.get("severity")
                if severity and severity.value == "critical":
                    status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                else:
                    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR

                return ErrorResponse.create(
                    status_code=status_code,
                    error_type="system_error",
                    message=outcome.get("user_message", "系统遇到未知错误"),
                    details={
                        "error_id": outcome.get("error_id"),
                        "recovery_info": "请稍后重试,如问题持续请联系管理员",
                    },
                )
            except Exception:
                # The resilience manager itself failed: log it and fall back.
                logger.exception("韧性管理器处理异常时出错")

        # Fallback: basic exception response.
        return ErrorResponse.from_exception(exc)
|
||||
247
src/core/key_capabilities.py
Normal file
247
src/core/key_capabilities.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Key 能力系统
|
||||
|
||||
能力类型:
|
||||
1. 互斥能力 (EXCLUSIVE): 需要时选有的,不需要时选没有的(如 cache_1h)
|
||||
2. 兼容能力 (COMPATIBLE): 需要时选有的,不需要时都可选(如 context_1m)
|
||||
|
||||
配置模式:
|
||||
1. user_configurable: 用户可配置(模型级 + Key级强制)
|
||||
2. auto_detect: 自动检测(请求失败后升级)
|
||||
3. request_param: 从请求参数检测
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class CapabilityMatchMode(Enum):
    """Capability match mode."""

    EXCLUSIVE = "exclusive"  # Exclusive: pick keys that have it when needed, keys that lack it when not needed
    COMPATIBLE = "compatible"  # Compatible: pick keys that have it when needed, any key when not needed
|
||||
|
||||
|
||||
class CapabilityConfigMode(Enum):
    """Capability configuration mode."""

    USER_CONFIGURABLE = "user_configurable"  # User-configurable (model level + forced at key level)
    AUTO_DETECT = "auto_detect"  # Auto-detected (upgraded after a request fails)
    REQUEST_PARAM = "request_param"  # Detected from request parameters
|
||||
|
||||
|
||||
@dataclass
class CapabilityDefinition:
    """Definition of a key capability."""

    name: str
    display_name: str
    description: str
    match_mode: CapabilityMatchMode
    config_mode: CapabilityConfigMode
    short_name: str = ""  # Short display name (for compact contexts such as lists)
    error_patterns: List[str] = field(default_factory=list)  # Keyword group used for error detection
|
||||
|
||||
|
||||
# ============ Capability registry ============

# Maps capability name -> definition; filled by register_capability().
_capabilities: Dict[str, CapabilityDefinition] = {}
|
||||
|
||||
|
||||
def register_capability(
    name: str,
    display_name: str,
    description: str,
    match_mode: CapabilityMatchMode,
    config_mode: CapabilityConfigMode,
    short_name: str = "",
    error_patterns: Optional[List[str]] = None,
) -> CapabilityDefinition:
    """Create a capability definition, store it in the registry and return it.

    ``short_name`` falls back to ``display_name`` when empty, and
    ``error_patterns`` to an empty list. An existing entry with the same
    ``name`` is replaced.
    """
    effective_short_name = short_name if short_name else display_name
    effective_patterns = error_patterns if error_patterns else []

    definition = CapabilityDefinition(
        name=name,
        display_name=display_name,
        description=description,
        match_mode=match_mode,
        config_mode=config_mode,
        short_name=effective_short_name,
        error_patterns=effective_patterns,
    )
    _capabilities[name] = definition
    return definition
|
||||
|
||||
|
||||
def get_capability(name: str) -> Optional[CapabilityDefinition]:
    """Look up a registered capability by name; ``None`` when it is not registered."""
    definition = _capabilities.get(name)
    return definition
|
||||
|
||||
|
||||
def get_all_capabilities() -> List[CapabilityDefinition]:
    """Return a new list holding every registered capability definition."""
    return [*_capabilities.values()]
|
||||
|
||||
|
||||
def get_user_configurable_capabilities() -> List[CapabilityDefinition]:
    """Return the registered capabilities whose config mode is USER_CONFIGURABLE."""
    selected = []
    for definition in _capabilities.values():
        if definition.config_mode == CapabilityConfigMode.USER_CONFIGURABLE:
            selected.append(definition)
    return selected
|
||||
|
||||
|
||||
# ============ 能力匹配检查 ============
|
||||
|
||||
|
||||
def check_capability_match(
    key_capabilities: Optional[Dict[str, bool]],
    requirements: Optional[Dict[str, bool]],
) -> Tuple[bool, Optional[str]]:
    """
    Check whether a key's capabilities satisfy a request's requirements.

    Matching rules:
    1. EXCLUSIVE capabilities:
       - required and the key has it        -> pass
       - required but the key lacks it      -> reject
       - not required but the key has it    -> reject (avoid wasting a pricier resource)
       - not required and the key lacks it  -> pass
       - not declared but the key has it    -> reject (undeclared counts as not required)

    2. COMPATIBLE capabilities:
       - required and the key has it        -> pass
       - required but the key lacks it      -> reject
       - not required / not declared        -> pass, whether or not the key has it

    Capability names that are not registered are ignored.

    Args:
        key_capabilities: Capabilities the key has, e.g. {"cache_1h": True, ...}
        requirements: Capabilities the request needs,
            e.g. {"cache_1h": True, "context_1m": False}

    Returns:
        (is_match, skip_reason) - whether the key matches and, if not, why it
        is skipped.
    """
    owned = key_capabilities or {}
    wanted = requirements or {}

    # Pass 1: capabilities the request declares explicitly.
    for name, needed in wanted.items():
        definition = _capabilities.get(name)
        if not definition:
            continue

        available = owned.get(name, False)
        mode = definition.match_mode
        is_exclusive = mode == CapabilityMatchMode.EXCLUSIVE
        is_compatible = mode == CapabilityMatchMode.COMPATIBLE

        # Both modes reject a key that lacks a required capability.
        if (is_exclusive or is_compatible) and needed and not available:
            return False, f"需要{definition.display_name}但 Key 不支持"

        # Only exclusive capabilities reject a key for having something unneeded.
        if is_exclusive and available and not needed:
            return False, f"不需要{definition.display_name}(避免浪费高价资源)"

    # Pass 2: exclusive capabilities the key has but the request never mentioned.
    for name, available in owned.items():
        if not available or name in wanted:
            continue

        definition = _capabilities.get(name)
        if definition and definition.match_mode == CapabilityMatchMode.EXCLUSIVE:
            return False, f"不需要{definition.display_name}(避免浪费高价资源)"

    return True, None
|
||||
|
||||
|
||||
def _match_error_patterns(error_msg: str, patterns: List[str]) -> bool:
|
||||
"""检查错误信息是否匹配模式(所有关键词都要出现)"""
|
||||
if not patterns:
|
||||
return False
|
||||
msg_lower = error_msg.lower()
|
||||
return all(p.lower() in msg_lower for p in patterns)
|
||||
|
||||
|
||||
def detect_capability_upgrade_from_error(
    error_msg: str,
    current_requirements: Optional[Dict[str, bool]] = None,
) -> Optional[str]:
    """
    Detect from an error message whether a capability should be upgraded.

    Args:
        error_msg: The error message.
        current_requirements: Capability requirements already in effect.

    Returns:
        The name of the first registered capability that is not already
        required and whose error patterns all occur in the message, or
        ``None`` when no upgrade applies.
    """
    active = current_requirements or {}

    for definition in _capabilities.values():
        if active.get(definition.name):
            continue  # already required, nothing to upgrade
        if not definition.error_patterns:
            continue  # this capability cannot be detected from errors
        if _match_error_patterns(error_msg, definition.error_patterns):
            return definition.name

    return None
|
||||
|
||||
|
||||
# ============ Compatibility aliases ============

# Kept for backward compatibility with the old API name.
get_capability_definition = get_capability
|
||||
|
||||
|
||||
class _CapabilityDefinitionsProxy:
    """Read-only, dict-like view of the capability registry (for legacy code).

    Every call reads the live registry, so capabilities registered later are
    visible. ``keys``/``values``/``items`` return list snapshots. ``__iter__``
    and ``__len__`` are defined so that ``for name in CAPABILITY_DEFINITIONS``
    and ``len(...)`` behave like a dict; without ``__iter__`` Python would
    fall back to integer indexing through ``__getitem__`` and raise
    ``KeyError(0)``.
    """

    def get(self, name: str) -> Optional[CapabilityDefinition]:
        """Return the definition for ``name``, or ``None`` if not registered."""
        return _capabilities.get(name)

    def __getitem__(self, name: str) -> CapabilityDefinition:
        """Return the definition for ``name``; raise ``KeyError`` if not registered."""
        result = _capabilities.get(name)
        if result is None:
            raise KeyError(name)
        return result

    def __contains__(self, name: str) -> bool:
        return name in _capabilities

    def __iter__(self):
        """Iterate over a snapshot of the registered capability names."""
        return iter(list(_capabilities))

    def __len__(self) -> int:
        return len(_capabilities)

    def keys(self) -> List[str]:
        """Return the registered capability names as a list."""
        return list(_capabilities.keys())

    def values(self) -> List[CapabilityDefinition]:
        """Return the registered definitions as a list."""
        return list(_capabilities.values())

    def items(self) -> List[Tuple[str, CapabilityDefinition]]:
        """Return ``(name, definition)`` pairs as a list."""
        return list(_capabilities.items())
|
||||
|
||||
|
||||
# Dict-style accessor over the registry, kept for legacy callers.
CAPABILITY_DEFINITIONS = _CapabilityDefinitionsProxy()


# ============ Legacy plugin base class (being phased out) ============

CapabilityPlugin = CapabilityDefinition  # Type alias kept for legacy code
|
||||
|
||||
|
||||
# ============ Built-in capability registrations ============

# 1-hour cache TTL: exclusive, user-configurable.
register_capability(
    name="cache_1h",
    display_name="1 小时缓存",
    description="使用 1 小时缓存 TTL(价格更高,适合长对话)",
    match_mode=CapabilityMatchMode.EXCLUSIVE,
    config_mode=CapabilityConfigMode.USER_CONFIGURABLE,
    short_name="1h缓存",
)

# 1M-token context window: compatible, detected from request parameters.
register_capability(
    name="context_1m",
    display_name="CLI 1M 上下文",
    description="支持 1M tokens 上下文窗口",
    match_mode=CapabilityMatchMode.COMPATIBLE,
    config_mode=CapabilityConfigMode.REQUEST_PARAM,
    short_name="CLI 1M",
    error_patterns=["context", "token", "length", "exceed"],  # Context-limit-exceeded errors
)
|
||||
135
src/core/logger.py
Normal file
135
src/core/logger.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
统一日志系统 - 基于 loguru
|
||||
|
||||
日志级别策略:
|
||||
- DEBUG: 开发调试,详细执行流程、变量值、缓存操作
|
||||
- INFO: 生产环境,关键业务操作、状态变更、请求处理
|
||||
- WARNING: 潜在问题、降级处理、资源警告
|
||||
- ERROR: 异常错误、需要关注的故障
|
||||
|
||||
输出策略:
|
||||
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
|
||||
- 文件: 始终保存 DEBUG 级别,保留30天,每日轮转
|
||||
|
||||
使用方式:
|
||||
from src.core.logger import logger
|
||||
|
||||
logger.info("消息")
|
||||
logger.debug("调试信息")
|
||||
logger.warning("警告")
|
||||
logger.error("错误")
|
||||
logger.exception("异常,带堆栈")
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# ============================================================================
# Environment detection
# ============================================================================

# Treated as Docker when /.dockerenv exists or DOCKER_CONTAINER is "true"
# (case-insensitive).
IS_DOCKER = (
    os.path.exists("/.dockerenv")
    or os.environ.get("DOCKER_CONTAINER", "false").lower() == "true"
)

# Log level: taken from LOG_LEVEL when set; otherwise DEBUG outside Docker
# and INFO inside Docker.
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG" if not IS_DOCKER else "INFO").upper()

# Whether file logging is disabled (for tests or special scenarios).
DISABLE_FILE_LOG = os.getenv("LOG_DISABLE_FILE", "false").lower() == "true"

# Project root: the directory two levels above this file's directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
# ============================================================================
# Log format definitions
# ============================================================================

# Colourised console format (time, level, message).
CONSOLE_FORMAT_DEV = (
    "<green>{time:HH:mm:ss}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{message}</cyan>"
)

# Plain console format with full date.
CONSOLE_FORMAT_PROD = "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}"

# File format: millisecond timestamps plus the call site (name:function:line).
FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}"
|
||||
|
||||
# ============================================================================
# Logging configuration
# ============================================================================

# Drop the handlers registered so far before adding this module's own sinks.
logger.remove()
|
||||
|
||||
|
||||
def _log_filter(record):
|
||||
return "watchfiles" not in record["name"]
|
||||
|
||||
|
||||
# Console sink: plain format without colours inside Docker, colourised
# development format otherwise. The level comes from LOG_LEVEL in both cases.
logger.add(
    sys.stdout,
    format=CONSOLE_FORMAT_PROD if IS_DOCKER else CONSOLE_FORMAT_DEV,
    level=LOG_LEVEL,
    filter=_log_filter,
    colorize=not IS_DOCKER,
)

if not DISABLE_FILE_LOG:
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(exist_ok=True)

    # Main log file - all levels from DEBUG up; rotated at midnight, kept for
    # 30 days, rotated files gzip-compressed.
    logger.add(
        log_dir.joinpath("app.log"),
        level="DEBUG",
        format=FILE_FORMAT,
        filter=_log_filter,
        rotation="00:00",
        retention="30 days",
        compression="gz",
        enqueue=True,
        encoding="utf-8",
    )

    # Error log file - ERROR and above only, same rotation policy.
    logger.add(
        log_dir.joinpath("error.log"),
        level="ERROR",
        format=FILE_FORMAT,
        filter=_log_filter,
        rotation="00:00",
        retention="30 days",
        compression="gz",
        enqueue=True,
        encoding="utf-8",
    )
|
||||
|
||||
# ============================================================================
# Silence noisy third-party loggers
# ============================================================================

# These are standard-library loggers; raise their thresholds.
logging.getLogger("watchfiles").setLevel(logging.ERROR)
logging.getLogger("watchfiles.main").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)

# ============================================================================
# Exports
# ============================================================================

__all__ = ["logger"]
|
||||
46
src/core/metrics.py
Normal file
46
src/core/metrics.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Prometheus metrics for monitoring
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
# Histogram bucket upper bounds, in seconds.
_SLOT_DURATION_BUCKETS = [0.1, 0.5, 1, 5, 10, 30, 60, 120, 300, 600]  # 0.1 s .. 10 min
_STREAMING_DURATION_BUCKETS = [1, 5, 10, 30, 60, 120, 300, 600, 1800]  # 1 s .. 30 min

# Distribution of how long a concurrency slot stays occupied.
concurrency_slot_duration_seconds = Histogram(
    "concurrency_slot_duration_seconds",
    "Duration of concurrency slot occupation in seconds",
    ["key_id", "exception"],
    buckets=_SLOT_DURATION_BUCKETS,
)

# Number of concurrency slot releases.
concurrency_slot_release_total = Counter(
    "concurrency_slot_release_total",
    "Total number of concurrency slot releases",
    ["key_id", "exception"],
)

# Concurrency slots currently held, per key.
concurrency_slots_in_use = Gauge(
    "concurrency_slots_in_use",
    "Current number of concurrency slots in use",
    ["key_id"],
)

# Distribution of streaming request durations.
streaming_request_duration_seconds = Histogram(
    "streaming_request_duration_seconds",
    "Duration of streaming requests in seconds",
    ["key_id", "status"],
    buckets=_STREAMING_DURATION_BUCKETS,
)

# Request totals by type.
# Label values: type = streaming / non-streaming, status = success / error.
request_total = Counter(
    "request_total",
    "Total number of requests",
    ["type", "status"],
)

# Health monitoring.
health_open_circuits = Gauge(
    "health_open_circuits",
    "Number of provider keys currently in circuit breaker open state",
)
|
||||
106
src/core/optimization_utils.py
Normal file
106
src/core/optimization_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
优化工具类 - 包含Token计数和响应头管理
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import tiktoken
|
||||
|
||||
|
||||
class TokenCounter:
    """
    Token counter backed by tiktoken, with a length-based fallback.

    Encoders are resolved from the model name by prefix and cached per
    encoding name. Whenever an encoder cannot be obtained or encoding fails,
    the count degrades to ``len(text) // 4``.
    """

    # Model-name prefix -> tiktoken encoding name. Prefixes are matched
    # case-insensitively against the start of the model name; the longest
    # matching prefix wins.
    MODEL_TO_ENCODING = {
        "gpt-4": "cl100k_base",
        "gpt-3.5-turbo": "cl100k_base",
        "claude-3": "cl100k_base",  # cl100k_base is used as an approximation for Claude
        "claude-2": "cl100k_base",
    }

    # Encoding used when no prefix in MODEL_TO_ENCODING matches.
    _DEFAULT_ENCODING_NAME = "cl100k_base"

    def __init__(self):
        # encoding name -> encoder object
        self._encodings = {}
        # lazily created encoder used when a mapped encoding cannot be loaded
        self._default_encoding = None

    @classmethod
    def _resolve_encoding_name(cls, model) -> str:
        """
        Return the encoding name for *model*.

        The previous implementation looked up ``model.split("-")[0]`` (e.g.
        "gpt"), which can never equal a key such as "gpt-4", so the mapping was
        never consulted. Longest-prefix matching makes the table effective.
        """
        normalized = str(model or "").lower()
        best_prefix = None
        for prefix in cls.MODEL_TO_ENCODING:
            if normalized.startswith(prefix.lower()):
                if best_prefix is None or len(prefix) > len(best_prefix):
                    best_prefix = prefix
        if best_prefix is None:
            return cls._DEFAULT_ENCODING_NAME
        return cls.MODEL_TO_ENCODING[best_prefix]

    def _get_encoding(self, model: str):
        """Return the (cached) encoder for *model*; may raise if none can be loaded."""
        encoding_name = self._resolve_encoding_name(model)

        if encoding_name not in self._encodings:
            try:
                self._encodings[encoding_name] = tiktoken.get_encoding(encoding_name)
            except Exception:
                # The mapped encoding could not be loaded: fall back to the
                # default one. If that fails too, the error propagates and the
                # caller switches to the length-based estimate.
                if self._default_encoding is None:
                    self._default_encoding = tiktoken.get_encoding(self._DEFAULT_ENCODING_NAME)
                self._encodings[encoding_name] = self._default_encoding

        return self._encodings[encoding_name]

    def count_tokens(self, text: str, model: str = "claude-3") -> int:
        """
        Count the tokens in *text*.

        Returns 0 for empty/None input. If encoding fails for any reason the
        result is the rough estimate ``len(text) // 4``.
        """
        if not text:
            return 0

        try:
            encoding = self._get_encoding(model)
            return len(encoding.encode(text))
        except Exception:
            # Deliberate best-effort: degrade to a simple estimate.
            return len(text) // 4

    def count_messages_tokens(self, messages: list, model: str = "claude-3") -> int:
        """
        Count the tokens of a list of chat messages.

        Each dict message contributes a fixed overhead of 4 (role and
        separators) plus the tokens of its ``content``: either a string, or a
        list of parts of which only dicts carrying a ``"text"`` key are counted.
        Non-dict messages are ignored.
        """
        total = 0
        for message in messages or []:
            if not isinstance(message, dict):
                continue

            total += 4  # role / separator overhead

            content = message.get("content", "")
            if isinstance(content, str):
                total += self.count_tokens(content, model)
            elif isinstance(content, list):
                # Multi-part (multimodal) content: only text parts are counted.
                for item in content:
                    if isinstance(item, dict) and "text" in item:
                        total += self.count_tokens(item["text"], model)

        return total

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten a response ``content`` value (None, list of parts, or scalar) to text."""
        if content is None:
            return ""
        if isinstance(content, list):
            # A part whose "text" is missing or None contributes an empty string.
            return " ".join(
                str(item.get("text") or "") for item in content if isinstance(item, dict)
            )
        return str(content)

    def estimate_response_tokens(self, response: Any, model: str = "claude-3") -> int:
        """
        Estimate the tokens of a response payload.

        Dicts with a ``"content"`` key are counted from that content; dicts
        with a ``"choices"`` key (OpenAI format) are counted from each
        choice's ``message.content``. Anything else falls back to
        ``len(str(response)) // 4``.
        """
        if isinstance(response, dict):
            if "content" in response:
                return self.count_tokens(self._content_to_text(response["content"]), model)

            if "choices" in response:
                # OpenAI format
                total = 0
                for choice in response.get("choices") or []:
                    if not isinstance(choice, dict):
                        continue
                    # "message" and its "content" may be None (e.g. tool calls).
                    message = choice.get("message") or {}
                    text = self._content_to_text(message.get("content"))
                    total += self.count_tokens(text, model)
                return total

        # Fall back to a simple estimate.
        return len(str(response)) // 4
|
||||
205
src/core/provider_health.py
Normal file
205
src/core/provider_health.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
提供商健康度管理
|
||||
基于简单的失败计数和优先级调整
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class ProviderHealthTracker:
    """
    Track per-provider health from recent successes and failures.

    Timestamps of successes and failures are kept for ``failure_window``
    seconds. Repeated failures push a provider's priority adjustment below
    zero; successes and a failure-free ``recovery_time`` bring it back.
    """

    def __init__(
        self,
        failure_window: int = 300,  # sliding window for records, seconds (5 min)
        failure_threshold: int = 3,  # failures in the window that trigger a demotion
        recovery_time: int = 600,  # failure-free seconds before a demotion is cleared (10 min)
    ):
        self.failure_window = failure_window
        self.failure_threshold = failure_threshold
        self.recovery_time = recovery_time

        # provider name -> timestamps of failures inside the window
        self.failures: Dict[str, list] = defaultdict(list)
        # provider name -> timestamps of successes inside the window
        self.successes: Dict[str, list] = defaultdict(list)
        # provider name -> priority adjustment (negative = demoted)
        self.priority_adjustments: Dict[str, int] = {}
        # provider name -> timestamp of the most recent failure. Kept apart
        # from `failures` because those entries expire after `failure_window`,
        # which may be shorter than `recovery_time`.
        self._last_failure_time: Dict[str, float] = {}

    def record_success(self, provider_name: str):
        """Record a successful request and, when enough have accumulated, undo one demotion step."""
        current_time = time.time()

        self.successes[provider_name].append(current_time)
        self._cleanup_old_records(provider_name, current_time)

        # With at least 5 successes inside the window, every further success
        # restores one step of a negative adjustment.
        if len(self.successes[provider_name]) >= 5:
            if self.priority_adjustments.get(provider_name, 0) < 0:
                self.priority_adjustments[provider_name] += 1

    def record_failure(self, provider_name: str):
        """Record a failed request and demote the provider once the threshold is reached."""
        current_time = time.time()

        self.failures[provider_name].append(current_time)
        self._last_failure_time[provider_name] = current_time
        self._cleanup_old_records(provider_name, current_time)

        recent_failures = len(self.failures[provider_name])
        if recent_failures >= self.failure_threshold:
            current_adjustment = self.priority_adjustments.get(provider_name, 0)
            self.priority_adjustments[provider_name] = current_adjustment - 1

    def get_priority_adjustment(self, provider_name: str) -> int:
        """
        Return the priority adjustment for a provider.

        Negative values lower the priority, positive values raise it;
        unknown providers yield 0.
        """
        return self.priority_adjustments.get(provider_name, 0)

    def get_health_status(self, provider_name: str) -> Dict:
        """
        Return a snapshot of the provider's health.

        Keys: ``provider``, ``recent_failures``, ``recent_successes``,
        ``failure_rate``, ``priority_adjustment`` and ``status``.
        """
        current_time = time.time()
        self._cleanup_old_records(provider_name, current_time)

        recent_failures = len(self.failures[provider_name])
        recent_successes = len(self.successes[provider_name])
        total_requests = recent_failures + recent_successes

        failure_rate = recent_failures / total_requests if total_requests > 0 else 0

        return {
            "provider": provider_name,
            "recent_failures": recent_failures,
            "recent_successes": recent_successes,
            "failure_rate": failure_rate,
            "priority_adjustment": self.get_priority_adjustment(provider_name),
            "status": self._get_status_label(failure_rate, recent_failures),
        }

    def _cleanup_old_records(self, provider_name: str, current_time: float):
        """Drop records older than the window and clear a demotion after ``recovery_time``."""
        window = self.failure_window
        self.failures[provider_name] = [
            t for t in self.failures[provider_name] if current_time - t < window
        ]
        self.successes[provider_name] = [
            t for t in self.successes[provider_name] if current_time - t < window
        ]

        # Clear a demotion once the provider has been failure-free for
        # `recovery_time`. This used to be tested against the success
        # timestamps, which have just been trimmed to `failure_window`, so the
        # configured recovery time was effectively never honoured.
        if not self.failures[provider_name] and self.priority_adjustments.get(provider_name, 0) < 0:
            last_failure = self._last_failure_time.get(provider_name)
            if last_failure is None or current_time - last_failure >= self.recovery_time:
                self.priority_adjustments[provider_name] = 0

    def _get_status_label(self, failure_rate: float, recent_failures: int) -> str:
        """Map the failure count / failure rate to a status label."""
        if recent_failures >= self.failure_threshold:
            return "degraded"
        elif failure_rate > 0.5:
            return "unstable"
        elif failure_rate > 0.1:
            return "warning"
        else:
            return "healthy"

    def should_use_provider(self, provider_name: str) -> bool:
        """
        Decide whether the provider should currently be used.

        Simple policy: a provider is skipped once its adjustment has dropped
        to -3 or lower.
        """
        adjustment = self.get_priority_adjustment(provider_name)
        return adjustment > -3

    def reset_provider_health(self, provider_name: str):
        """Reset all health state of a provider (manual administrator action)."""
        self.failures[provider_name] = []
        self.successes[provider_name] = []
        self.priority_adjustments[provider_name] = 0
        self._last_failure_time.pop(provider_name, None)
|
||||
|
||||
|
||||
class SimpleProviderSelector:
    """
    Minimal provider selector.

    Combines each provider's base priority with the adjustment reported by
    the health tracker.
    """

    def __init__(self, health_tracker: ProviderHealthTracker):
        self.health_tracker = health_tracker

    def select_provider(self, providers: list, specified_provider: Optional[str] = None):
        """
        Pick a provider.

        Args:
            providers: candidate providers (objects exposing ``name``,
                ``priority`` and ``id``).
            specified_provider: provider name requested by the user, if any.

        Returns:
            The chosen provider, or None when nothing matches / the list is empty.
        """
        tracker = self.health_tracker

        # An explicitly requested provider is returned regardless of its health.
        if specified_provider:
            for candidate in providers:
                if candidate.name == specified_provider:
                    return candidate
            return None

        def rank(provider):
            # Effective priority first, then the negated id as the tie-breaker.
            effective = provider.priority + tracker.get_priority_adjustment(provider.name)
            return (effective, -provider.id)

        # Highest rank first.
        ranked = sorted(providers, key=rank, reverse=True)

        # First provider the tracker still considers usable.
        for candidate in ranked:
            if tracker.should_use_provider(candidate.name):
                return candidate

        # Nothing is healthy: degrade to the top-ranked provider, if any.
        if ranked:
            return ranked[0]
        return None

    def get_provider_rankings(self, providers: list) -> list:
        """Return the providers' current ranking details (for debugging and monitoring)."""

        def describe(provider):
            health = self.health_tracker.get_health_status(provider.name)
            adjustment = health["priority_adjustment"]
            return {
                "name": provider.name,
                "base_priority": provider.priority,
                "adjustment": adjustment,
                "effective_priority": provider.priority + adjustment,
                "status": health["status"],
                "failure_rate": health["failure_rate"],
            }

        rows = [describe(provider) for provider in providers]
        # Highest effective priority first.
        return sorted(rows, key=lambda row: row["effective_priority"], reverse=True)
|
||||
428
src/core/resilience.py
Normal file
428
src/core/resilience.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
系统韧性和风险管控模块
|
||||
提供全局的错误处理、自动恢复、降级策略和用户友好的错误体验
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
from ..core.exceptions import ProxyException
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class ErrorSeverity(Enum):
    """How severe an error is."""

    # Minor: core functionality is unaffected.
    LOW = "low"
    # Moderate: part of the functionality is affected.
    MEDIUM = "medium"
    # Major: primary functionality is affected.
    HIGH = "high"
    # Critical: system availability is affected.
    CRITICAL = "critical"
|
||||
|
||||
|
||||
class RecoveryStrategy(Enum):
    """Strategy used to recover from an error."""

    # Try the operation again.
    RETRY = "retry"
    # Switch to a fallback.
    FALLBACK = "fallback"
    # Trip a circuit breaker.
    CIRCUIT_BREAKER = "circuit_breaker"
    # Degrade gracefully.
    GRACEFUL_DEGRADE = "graceful_degrade"
    # Report the problem to the user.
    USER_NOTIFY = "user_notify"
|
||||
|
||||
|
||||
class ErrorPattern:
    """
    Describes how a family of exceptions should be handled.

    Attributes:
        error_types: exception classes this pattern applies to.
        severity: severity assigned to matching errors.
        recovery_strategy: strategy associated with matching errors.
        user_message: user-facing message for matching errors.
        auto_recover: whether automatic recovery should be attempted.
        max_retries: retry budget stored with the pattern.
        retry_delay: retry delay stored with the pattern.
        circuit_threshold: circuit-breaker threshold stored with the pattern.
    """

    def __init__(
        self,
        error_types: List[Type[Exception]],
        severity: ErrorSeverity,
        recovery_strategy: RecoveryStrategy,
        user_message: str,
        auto_recover: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        circuit_threshold: int = 5,
    ):
        # What is matched and how it is classified.
        self.error_types = error_types
        self.severity = severity
        self.user_message = user_message

        # How recovery is supposed to proceed.
        self.recovery_strategy = recovery_strategy
        self.auto_recover = auto_recover

        # Tuning knobs.
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.circuit_threshold = circuit_threshold
|
||||
|
||||
|
||||
class CircuitBreaker:
    """
    Simple circuit breaker with the states "closed", "open" and "half-open".

    After ``failure_threshold`` failures without an intervening success the
    breaker opens and rejects calls. Once ``timeout`` seconds have passed
    since the last failure, the next call is let through in "half-open"
    state; a success closes the breaker again.
    """

    def __init__(self, failure_threshold: int = 5, timeout: int = 60):
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.state = "closed"  # closed, open, half-open
        # Guards every read-modify-write of state / failure_count.
        self._lock = threading.Lock()

    def call(self, func: Callable, *args, **kwargs):
        """
        Invoke ``func(*args, **kwargs)`` under circuit-breaker control.

        ``func`` is called synchronously and its return value is returned
        as-is. If it returns an awaitable, only the creation of that awaitable
        is covered, not its later execution.

        Raises:
            Exception: when the breaker is open and the timeout has not yet
                elapsed; otherwise whatever ``func`` raises.
        """
        # Admission check only: the lock is released before func runs, so a
        # slow call does not block other callers.
        with self._lock:
            if self.state == "open":
                if not self._should_attempt_reset():
                    raise Exception("服务暂时不可用,请稍后重试")
                self.state = "half-open"

        try:
            result = func(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise

        self._on_success()
        return result

    def _should_attempt_reset(self) -> bool:
        """Return True when enough time has passed since the last failure to probe again."""
        if self.last_failure_time is None:
            return True
        return time.time() - self.last_failure_time >= self.timeout

    def _on_success(self):
        """Reset the failure counter and close the breaker."""
        # Updates used to run without the lock although `call` takes it for
        # the state check; take it here as well so updates are not lost.
        with self._lock:
            self.failure_count = 0
            self.state = "closed"

    def _on_failure(self):
        """Count a failure and open the breaker once the threshold is reached."""
        with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count >= self.failure_threshold:
                self.state = "open"
|
||||
|
||||
|
||||
class ResilienceManager:
    """
    Central registry for error patterns, circuit breakers and error statistics.

    ``handle_error`` records an error, updates the statistics and returns the
    handling instructions of the first registered pattern that matches.
    """

    # Maximum number of entries retained in ``last_errors``.
    _MAX_RECENT_ERRORS = 100

    def __init__(self):
        self.error_patterns: List[ErrorPattern] = []
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}
        # "<ExceptionName>:<operation>" -> occurrence count
        self.error_stats: Dict[str, int] = {}
        # Most recent error records, oldest first.
        self.last_errors: List[Dict[str, Any]] = []
        self._setup_default_patterns()

    def _setup_default_patterns(self):
        """Register the default patterns: database, authentication, network (in that order)."""

        # Database errors - only specific database exception types.
        try:
            from sqlalchemy.exc import (
                DatabaseError,
                DisconnectionError,
                OperationalError,
                StatementError,
            )
            from sqlalchemy.exc import TimeoutError as SQLTimeoutError

            db_exceptions = [
                OperationalError,
                DisconnectionError,
                SQLTimeoutError,
                StatementError,
                DatabaseError,
            ]
        except ImportError:
            # SQLAlchemy is unavailable: use generic exception types.
            # NOTE(review): ConnectionError and TimeoutError are OSError
            # subclasses and patterns are matched in registration order, so in
            # this fallback the database pattern also captures the errors
            # targeted by the network pattern below - confirm this is intended.
            db_exceptions = [ConnectionError, OSError]

        self.add_error_pattern(
            ErrorPattern(
                error_types=db_exceptions,
                severity=ErrorSeverity.HIGH,
                recovery_strategy=RecoveryStrategy.RETRY,
                user_message="数据库连接异常,正在重试...",
                max_retries=3,
                retry_delay=1.0,
            )
        )

        # Authentication errors - only specific authentication exception types.
        try:
            from ..core.exceptions import ForbiddenException, ProviderAuthException

            auth_exceptions = [ProviderAuthException, ForbiddenException]
        except ImportError:
            # Stay conservative: do not substitute generic exception types.
            auth_exceptions = []

        if auth_exceptions:
            self.add_error_pattern(
                ErrorPattern(
                    error_types=auth_exceptions,
                    severity=ErrorSeverity.MEDIUM,
                    recovery_strategy=RecoveryStrategy.USER_NOTIFY,
                    user_message="认证失败,请检查API密钥或重新登录",
                    auto_recover=False,
                )
            )

        # Network request errors.
        self.add_error_pattern(
            ErrorPattern(
                error_types=[ConnectionError, TimeoutError],
                severity=ErrorSeverity.MEDIUM,
                recovery_strategy=RecoveryStrategy.FALLBACK,
                user_message="网络连接异常,正在尝试备用方案...",
                max_retries=2,
            )
        )

    def add_error_pattern(self, pattern: ErrorPattern):
        """Append a pattern; earlier patterns take precedence when matching."""
        self.error_patterns.append(pattern)

    def get_circuit_breaker(self, key: str) -> CircuitBreaker:
        """Return the circuit breaker registered under *key*, creating it on first use."""
        if key not in self.circuit_breakers:
            self.circuit_breakers[key] = CircuitBreaker()
        return self.circuit_breakers[key]

    def handle_error(
        self,
        error: Exception,
        context: Optional[Dict[str, Any]] = None,
        operation: str = "unknown",
    ) -> Dict[str, Any]:
        """
        Record *error* and return how it should be handled.

        Args:
            error: the exception that occurred.
            context: optional extra data stored with the error record.
            operation: name of the operation that failed.

        Returns:
            A dict with ``error_id``, ``severity``, ``recovery_strategy``,
            ``user_message``, ``auto_recover`` and ``pattern`` (the matching
            ErrorPattern, or None when nothing matched).
        """
        error_id = str(uuid.uuid4())[:8]
        context = context or {}

        error_info = {
            "error_id": error_id,
            "error_type": type(error).__name__,
            "error_message": str(error),
            "operation": operation,
            "context": context,
            "timestamp": datetime.now(timezone.utc),
            # Format from the exception object itself. traceback.format_exc()
            # only reports the exception currently being handled and yields
            # "NoneType: None" when this method is called outside an except
            # block.
            "traceback": "".join(
                traceback.format_exception(type(error), error, error.__traceback__)
            ),
        }

        self.last_errors.append(error_info)
        # Keep only the most recent records.
        if len(self.last_errors) > self._MAX_RECENT_ERRORS:
            self.last_errors.pop(0)

        # Update the statistics.
        error_key = f"{type(error).__name__}:{operation}"
        self.error_stats[error_key] = self.error_stats.get(error_key, 0) + 1

        pattern = self._find_matching_pattern(error)

        if pattern:
            logger.error(f"错误处理 [{error_id}]: {pattern.user_message}")

            return {
                "error_id": error_id,
                "severity": pattern.severity,
                "recovery_strategy": pattern.recovery_strategy,
                "user_message": pattern.user_message,
                "auto_recover": pattern.auto_recover,
                "pattern": pattern,
            }

        # No pattern matched: default handling.
        logger.error(f"未知错误 [{error_id}]: {str(error)}")

        return {
            "error_id": error_id,
            "severity": ErrorSeverity.MEDIUM,
            "recovery_strategy": RecoveryStrategy.USER_NOTIFY,
            "user_message": "系统遇到未知错误,请稍后重试或联系管理员",
            "auto_recover": False,
            "pattern": None,
        }

    def _find_matching_pattern(self, error: Exception) -> Optional[ErrorPattern]:
        """Return the first registered pattern whose error types include *error*, else None."""
        for pattern in self.error_patterns:
            if any(isinstance(error, error_type) for error_type in pattern.error_types):
                return pattern
        return None

    def get_error_stats(self) -> Dict[str, Any]:
        """Return aggregate error counters and the state of every circuit breaker."""
        return {
            "total_errors": sum(self.error_stats.values()),
            "error_breakdown": self.error_stats.copy(),
            "recent_errors": len(self.last_errors),
            "circuit_breakers": {
                key: {"state": cb.state, "failure_count": cb.failure_count}
                for key, cb in self.circuit_breakers.items()
            },
        }
|
||||
|
||||
|
||||
# Module-wide ResilienceManager instance shared by the helpers below.
resilience_manager = ResilienceManager()
|
||||
|
||||
|
||||
def resilient_operation(
    operation_name: Optional[str] = None,
    max_retries: Optional[int] = None,
    retry_delay: Optional[float] = None,
    circuit_breaker_key: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
):
    """
    Decorator adding retries, optional circuit breaking and error recording.

    Works for both coroutine functions and plain functions; the returned
    wrapper is of the same kind as the decorated function.

    Args:
        operation_name: name used when recording errors; defaults to
            ``"<module>.<function name>"`` of the decorated function.
        max_retries: retries after the first attempt. ``None`` means 3; an
            explicit 0 disables retrying.
        retry_delay: base delay in seconds. ``None`` means 1.0. The wait
            before retry *k* is ``retry_delay * k`` (linear backoff).
        circuit_breaker_key: when set, calls are guarded by the circuit
            breaker registered under this key.
        context: extra data attached to every recorded error.

    Raises:
        ProxyException: when the last attempt fails or the matched error
            pattern disables automatic recovery. The original exception is
            chained as ``__cause__``.
    """
    # `x or default` would silently replace an explicit 0 with the default.
    retries = 3 if max_retries is None else max(0, max_retries)
    base_delay = 1.0 if retry_delay is None else max(0.0, retry_delay)

    def decorator(func: Callable) -> Callable:
        def _enter_breaker():
            """Return the configured breaker after passing its admission check, else None."""
            if not circuit_breaker_key:
                return None
            cb = resilience_manager.get_circuit_breaker(circuit_breaker_key)
            # Same admission rule as CircuitBreaker.call. It is applied here
            # because a coroutine function must be awaited before its outcome
            # is known; passing it through cb.call would record a success as
            # soon as the coroutine object is created.
            with cb._lock:
                if cb.state == "open":
                    if not cb._should_attempt_reset():
                        raise Exception("服务暂时不可用,请稍后重试")
                    cb.state = "half-open"
            return cb

        def _after_failure(error: Exception, attempt: int) -> float:
            """Record the error; return the wait before the next attempt or raise ProxyException."""
            op_name = operation_name or f"{func.__module__}.{func.__name__}"
            error_result = resilience_manager.handle_error(
                error=error,
                context={**(context or {}), "attempt": attempt + 1, "max_retries": retries},
                operation=op_name,
            )

            # Give up on the last attempt or when the pattern forbids auto recovery.
            if attempt == retries or not error_result.get("auto_recover", True):
                raise ProxyException(
                    status_code=500,
                    error_type="system_error",
                    message=error_result["user_message"],
                    details={
                        "error_id": error_result["error_id"],
                        "original_error": str(error),
                    },
                ) from error

            return base_delay * (attempt + 1)  # linear backoff

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            last_error = None

            for attempt in range(retries + 1):
                try:
                    cb = _enter_breaker()
                    try:
                        result = await func(*args, **kwargs)
                    except Exception:
                        if cb is not None:
                            cb._on_failure()
                        raise
                    if cb is not None:
                        cb._on_success()
                    return result
                except Exception as e:
                    last_error = e
                    await asyncio.sleep(_after_failure(e, attempt))

            # Not expected to be reached; kept as a safety net.
            raise last_error

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Native synchronous retry loop. Driving the async wrapper through
            # asyncio.run() fails whenever the decorated function is called
            # from code that already runs inside an event loop.
            last_error = None

            for attempt in range(retries + 1):
                try:
                    cb = _enter_breaker()
                    try:
                        result = func(*args, **kwargs)
                    except Exception:
                        if cb is not None:
                            cb._on_failure()
                        raise
                    if cb is not None:
                        cb._on_success()
                    return result
                except Exception as e:
                    last_error = e
                    time.sleep(_after_failure(e, attempt))

            # Not expected to be reached; kept as a safety net.
            raise last_error

        # Return the wrapper that matches the kind of the decorated function.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
@asynccontextmanager
async def safe_operation(operation_name: str, context: Dict[str, Any] = None):
    """
    Async context manager that routes exceptions through the resilience manager.

    Errors classified as HIGH or CRITICAL are re-raised as ProxyException
    carrying the user-facing message; any other error is logged as a warning
    and suppressed, so execution continues after the ``async with`` block.
    """
    try:
        yield
    except Exception as e:
        error_result = resilience_manager.handle_error(
            error=e,
            context=context or {},
            operation=operation_name,
        )

        must_abort = error_result["severity"] in [ErrorSeverity.HIGH, ErrorSeverity.CRITICAL]
        if not must_abort:
            # Log a warning but do not interrupt the operation.
            logger.warning(f"操作警告 [{error_result['error_id']}]: {error_result['user_message']}")
            return

        raise ProxyException(
            status_code=500,
            error_type="system_error",
            message=error_result["user_message"],
            details={"error_id": error_result["error_id"]},
        )
|
||||
|
||||
|
||||
def graceful_degradation(fallback_func: Callable = None, fallback_value: Any = None):
    """
    Decorator that switches to a fallback when the wrapped function fails.

    Args:
        fallback_func: optional callable (plain or coroutine function) invoked
            with the same arguments when the wrapped function raises.
        fallback_value: value returned when the wrapped function raises and no
            ``fallback_func`` was given.

    Raises:
        The original exception of the wrapped function when ``fallback_func``
        fails as well.
    """

    def decorator(func: Callable) -> Callable:
        def _log_primary_failure():
            logger.warning(f"主要功能失败,启用降级模式: {func.__name__}")

        def _log_fallback_failure():
            # Not every callable has __name__ (e.g. functools.partial objects).
            fallback_name = getattr(fallback_func, "__name__", repr(fallback_func))
            logger.exception(f"降级方案也失败了: {fallback_name}")

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                _log_primary_failure()

                if not fallback_func:
                    return fallback_value

                try:
                    if asyncio.iscoroutinefunction(fallback_func):
                        return await fallback_func(*args, **kwargs)
                    return fallback_func(*args, **kwargs)
                except Exception:
                    _log_fallback_failure()
                    raise e  # surface the original error

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Plain functions are handled without an event loop. The former
            # `asyncio.run(...)` lambda failed when called from code already
            # running inside a loop and discarded the function's metadata.
            try:
                return func(*args, **kwargs)
            except Exception as e:
                _log_primary_failure()

                if not fallback_func:
                    return fallback_value

                try:
                    if asyncio.iscoroutinefunction(fallback_func):
                        # A coroutine fallback for a synchronous function still
                        # needs a loop to run in.
                        return asyncio.run(fallback_func(*args, **kwargs))
                    return fallback_func(*args, **kwargs)
                except Exception:
                    _log_fallback_failure()
                    raise e  # surface the original error

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
# Public interface of this module.
__all__ = [
    "resilience_manager",
    "resilient_operation",
    "safe_operation",
    "graceful_degradation",
    "ErrorSeverity",
    "RecoveryStrategy",
    "ErrorPattern",
]
|
||||
181
src/core/validators.py
Normal file
181
src/core/validators.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
输入验证器
|
||||
包含密码复杂度验证和其他输入验证
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class PasswordValidator:
    """Password acceptance checks and a heuristic strength rating."""

    MIN_LENGTH = 6  # lowered to 6 characters
    MAX_LENGTH = 128

    # Common weak passwords, stored lower-cased because the comparison in
    # validate() is case-insensitive.
    _WEAK_PASSWORDS = frozenset(
        candidate.lower()
        for candidate in (
            "password123",
            "admin123",
            "12345678",
            "qwerty123",
            "password@123",
            "admin@123",
            "Password123!",
            "Admin123!",
        )
    )

    @classmethod
    def validate(cls, password: str) -> tuple[bool, Optional[str]]:
        """
        Check whether a password is acceptable.

        Only the length is constrained (no character-class requirements);
        in addition a short list of common weak passwords is rejected.

        Args:
            password: the password to check.

        Returns:
            ``(ok, error_message)``; ``error_message`` is None when ``ok`` is True.
        """
        if not password:
            return False, "密码不能为空"

        length = len(password)
        if length < cls.MIN_LENGTH:
            return False, f"密码长度至少为{cls.MIN_LENGTH}个字符"
        if length > cls.MAX_LENGTH:
            return False, f"密码长度不能超过{cls.MAX_LENGTH}个字符"

        if password.lower() in cls._WEAK_PASSWORDS:
            return False, "密码过于简单,请使用更复杂的密码"

        return True, None

    @classmethod
    def get_password_strength(cls, password: str) -> str:
        """
        Rate the strength of a password.

        Args:
            password: the password to rate.

        Returns:
            One of "弱" (weak), "中" (medium), "强" (strong), "非常强"
            (very strong), or "无效" (invalid) for an empty password.
        """
        if not password:
            return "无效"

        # One point for each length milestone reached.
        size = len(password)
        score = sum(1 for milestone in (8, 12, 16) if size >= milestone)

        # (pattern, points) awarded when the pattern occurs anywhere.
        weighted_patterns = (
            (r"[a-z]", 1),
            (r"[A-Z]", 1),
            (r"\d", 1),
            (r'[!@#$%^&*()_+\-=\[\]{};:\'",.<>?/\\|`~]', 2),
            (r"[^\w\s]", 1),  # any character that is neither word nor whitespace
        )
        score += sum(
            points for pattern, points in weighted_patterns if re.search(pattern, password)
        )

        # First ceiling the score stays below decides the label.
        for ceiling, label in ((3, "弱"), (5, "中"), (7, "强")):
            if score < ceiling:
                return label
        return "非常强"
|
||||
|
||||
|
||||
class EmailValidator:
    """E-mail address format validator."""

    EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")

    @classmethod
    def validate(cls, email: str) -> tuple[bool, Optional[str]]:
        """
        Check the format of an e-mail address.

        Args:
            email: the address to check.

        Returns:
            ``(ok, error_message)``; ``error_message`` is None when ``ok`` is True.
        """
        if not email:
            return False, "邮箱不能为空"

        if len(email) > 255:
            return False, "邮箱长度不能超过255个字符"

        # fullmatch instead of match: "$" also matches just before a trailing
        # newline, so match() would accept "user@example.com\n".
        if not cls.EMAIL_REGEX.fullmatch(email):
            return False, "邮箱格式不正确"

        return True, None
|
||||
|
||||
|
||||
class UsernameValidator:
    """Username validator: length, allowed characters and reserved names."""

    MIN_LENGTH = 3
    MAX_LENGTH = 30
    USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

    # Names that cannot be registered (compared case-insensitively).
    _RESERVED_NAMES = frozenset(
        (
            "admin",
            "root",
            "system",
            "api",
            "test",
            "demo",
            "user",
            "guest",
            "bot",
            "webhook",
            "support",
        )
    )

    @classmethod
    def validate(cls, username: str) -> tuple[bool, Optional[str]]:
        """
        Check a username.

        Args:
            username: the username to check.

        Returns:
            ``(ok, error_message)``; ``error_message`` is None when ``ok`` is True.
        """
        if not username:
            return False, "用户名不能为空"

        if len(username) < cls.MIN_LENGTH:
            return False, f"用户名长度至少为{cls.MIN_LENGTH}个字符"

        if len(username) > cls.MAX_LENGTH:
            return False, f"用户名长度不能超过{cls.MAX_LENGTH}个字符"

        # fullmatch instead of match: "$" also matches just before a trailing
        # newline, so match() would accept "admin\n", which then also slips
        # past the reserved-name check below.
        if not cls.USERNAME_REGEX.fullmatch(username):
            return False, "用户名只能包含字母、数字、下划线和连字符"

        if username.lower() in cls._RESERVED_NAMES:
            return False, "该用户名为系统保留用户名"

        return True, None
|
||||
Reference in New Issue
Block a user