Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

3
src/plugins/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
插件系统基础模块
"""

View File

@@ -0,0 +1,8 @@
"""
认证插件模块
"""
from .api_key import ApiKeyAuthPlugin
from .base import AuthContext, AuthPlugin
__all__ = ["AuthPlugin", "AuthContext", "ApiKeyAuthPlugin"]

View File

@@ -0,0 +1,96 @@
"""
API Key认证插件
支持从header中提取API Key进行认证
"""
from typing import Optional
from fastapi import Request
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.services.auth.service import AuthService
from src.services.usage.service import UsageService
from .base import AuthContext, AuthPlugin
class ApiKeyAuthPlugin(AuthPlugin):
    """API Key authentication plugin.

    Extracts an API key from the ``x-api-key`` header or from an
    ``Authorization: Bearer <key>`` header and authenticates it.
    """

    def __init__(self):
        super().__init__(name="api_key", priority=10)

    def get_credentials(self, request: Request) -> Optional[str]:
        """Extract the API key from the request headers.

        Supported formats:
        1. ``x-api-key: <key>``
        2. ``Authorization: Bearer <key>``
        """
        # Prefer the dedicated x-api-key header.
        api_key = request.headers.get("x-api-key")
        if api_key:
            return api_key
        # Fall back to the Authorization header.
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # Bug fix: strip only the leading prefix. str.replace() removed
            # EVERY "Bearer " occurrence, corrupting keys that happen to
            # contain that substring.
            return auth_header[len("Bearer "):]
        return None

    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """Authenticate via API key; returns None when authentication fails."""
        # Extract the API key.
        api_key = self.get_credentials(request)
        if not api_key:
            logger.debug("未找到API Key凭据")
            return None
        # Validate the API key.
        auth_result = AuthService.authenticate_api_key(db, api_key)
        if not auth_result:
            logger.warning("API Key认证失败")
            return None
        user, api_key_obj = auth_result
        # Check the user's quota (or the key's standalone balance).
        quota_ok, message = UsageService.check_user_quota(db, user, api_key=api_key_obj)
        # Build the authentication context.
        auth_context = AuthContext(
            user_id=user.id,
            user_name=user.username,
            api_key_id=api_key_obj.id,
            api_key_name=api_key_obj.name if hasattr(api_key_obj, "name") else None,
            permissions={
                "can_use_api": quota_ok,
                "is_admin": user.is_admin if hasattr(user, "is_admin") else False,
                "is_standalone_key": api_key_obj.is_standalone,  # key carries its own balance
            },
            quota_info={
                "quota_usd": user.quota_usd,
                "used_usd": user.used_usd,
                "remaining_usd": None if user.quota_usd is None else user.quota_usd - user.used_usd,
                "quota_ok": quota_ok,
                "message": message,
            },
            metadata={
                "auth_method": "api_key",
                "client_ip": request.client.host if request.client else "unknown",
                "is_standalone": api_key_obj.is_standalone,  # duplicated for metadata consumers
            },
        )
        logger.info("API Key认证成功")
        return auth_context

120
src/plugins/auth/base.py Normal file
View File

@@ -0,0 +1,120 @@
"""
认证插件基类
定义认证插件的接口和认证上下文
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fastapi import Request
from sqlalchemy.orm import Session
from ..common import BasePlugin, HealthStatus, PluginMetadata
@dataclass
class AuthContext:
    """Authentication context produced by a successful auth-plugin run.

    Carries the authenticated user's identity, permissions and quota data.
    """

    user_id: int
    user_name: str
    api_key_id: Optional[int] = None
    api_key_name: Optional[str] = None
    # Mutable mappings default to None and are replaced with fresh dicts in
    # __post_init__ (a shared mutable default would leak state across instances).
    permissions: Optional[Dict[str, bool]] = None
    quota_info: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.permissions is None:
            self.permissions = {}
        if self.quota_info is None:
            # Bug fix: quota_info previously stayed None (unlike the other two
            # mappings), forcing every consumer to None-check it.
            self.quota_info = {}
        if self.metadata is None:
            self.metadata = {}
class AuthPlugin(BasePlugin):
    """Base class for authentication plugins.

    Every authentication plugin must subclass this and implement both
    ``authenticate`` and ``get_credentials``.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """Initialize the authentication plugin.

        Args:
            name: plugin name
            priority: selection priority (higher wins)
            version: plugin version
            author: plugin author
            description: human-readable description
            api_version: plugin-API version targeted
            dependencies: names of plugins this one depends on
            provides: service names this plugin offers
            config: configuration mapping
        """
        base_kwargs = dict(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )
        super().__init__(**base_kwargs)

    @abstractmethod
    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """Run authentication.

        Args:
            request: incoming FastAPI request
            db: database session

        Returns:
            An AuthContext on success, None on failure.
        """

    @abstractmethod
    def get_credentials(self, request: Request) -> Optional[str]:
        """Extract the raw credential string from the request.

        Args:
            request: incoming FastAPI request

        Returns:
            The credential string, or None when absent.
        """

    def is_applicable(self, request: Request) -> bool:
        """Whether this plugin applies to the given request.

        By default a plugin applies whenever it can extract credentials.
        """
        credentials = self.get_credentials(request)
        return credentials is not None

103
src/plugins/auth/jwt.py Normal file
View File

@@ -0,0 +1,103 @@
"""
JWT认证插件
支持JWT Bearer token认证
"""
from typing import Optional
from fastapi import Request
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import User
from src.services.auth.service import AuthService
from .base import AuthContext, AuthPlugin
class JwtAuthPlugin(AuthPlugin):
    """JWT authentication plugin.

    Extracts a JWT from an ``Authorization: Bearer <token>`` header and
    authenticates it against the user database.
    """

    def __init__(self):
        # Higher priority than the API-key plugin so JWT wins when both match.
        super().__init__(name="jwt", priority=20)

    def get_credentials(self, request: Request) -> Optional[str]:
        """Extract the JWT from the Authorization header.

        Supported format: ``Authorization: Bearer <token>``
        """
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # Bug fix: strip only the leading prefix. str.replace() removed
            # EVERY "Bearer " occurrence, corrupting tokens that happen to
            # contain that substring.
            return auth_header[len("Bearer "):]
        return None

    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """Authenticate via JWT token; returns None when authentication fails."""
        token = self.get_credentials(request)
        if not token:
            logger.debug("未找到JWT token")
            return None
        # NOTE(review): logging even a token prefix puts credential material
        # into logs — consider removing or hashing. Behavior kept as-is.
        logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
        try:
            # Verify the JWT's signature and decode its payload.
            payload = AuthService.verify_token(token)
            logger.debug(f"JWT token验证成功, payload: {payload}")
            user_id = payload.get("user_id")
            if not user_id:
                logger.warning("JWT token中缺少用户ID")
                return None
            logger.debug(f"从JWT提取user_id: {user_id}, 类型: {type(user_id)}")
            # Load the user from the database.
            user = db.query(User).filter(User.id == user_id).first()
            if not user:
                logger.warning(f"JWT认证失败 - 用户不存在: {user_id}")
                return None
            logger.debug(f"找到用户: {user.email}, is_active: {user.is_active}")
            if not user.is_active:
                logger.warning(f"JWT认证失败 - 用户已禁用: {user.email}")
                return None
            # Build the authentication context.
            auth_context = AuthContext(
                user_id=user.id,
                user_name=user.username,
                permissions={"can_use_api": True, "is_admin": user.role.value == "admin"},
                quota_info={
                    "quota_usd": user.quota_usd,
                    "used_usd": user.used_usd,
                    "remaining_usd": (
                        None if user.quota_usd is None else user.quota_usd - user.used_usd
                    ),
                    "quota_ok": True,  # JWT users are assumed pre-validated by the frontend
                },
                metadata={
                    "auth_method": "jwt",
                    "client_ip": request.client.host if request.client else "unknown",
                    "token_exp": payload.get("exp"),
                },
            )
            logger.info("JWT认证成功")
            return auth_context
        except Exception as e:
            # Broad catch is deliberate at this boundary: any verification
            # error means "not authenticated", never a 500 from the auth layer.
            logger.warning(f"JWT认证失败: {str(e)}")
            return None

5
src/plugins/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,5 @@
"""缓存插件包"""
from .base import CachePlugin
__all__ = ["CachePlugin"]

218
src/plugins/cache/base.py vendored Normal file
View File

@@ -0,0 +1,218 @@
"""
缓存插件基类
定义缓存插件的接口
"""
import hashlib
import json
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Dict, List, Optional
from ..common import BasePlugin, HealthStatus, PluginMetadata
class CachePlugin(BasePlugin):
    """Base class for cache plugins.

    Concrete caches implement the single-key operations (get/set/delete/
    exists/clear), the batch operations (get_many/set_many) and get_stats.
    This base supplies key generation, JSON (de)serialization helpers and
    the common configuration knobs (default_ttl, max_size).
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """Initialize the cache plugin.

        Args:
            name: plugin name
            priority: priority
            version: plugin version
            author: plugin author
            description: human-readable description
            api_version: plugin-API version targeted
            dependencies: dependency list
            provides: provided-service list
            config: configuration mapping
        """
        base_kwargs = dict(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )
        super().__init__(**base_kwargs)
        # Common cache tunables, read from the merged plugin config.
        self.default_ttl = self.config.get("default_ttl", 3600)  # seconds (1 hour)
        self.max_size = self.config.get("max_size", 1000)  # max cached entries

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, or None when absent."""

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: cache key
            value: value to cache
            ttl: expiry in seconds; None means the default TTL

        Returns:
            True when the value was stored.
        """

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Remove ``key``. Returns True when an entry was removed."""

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Return True when ``key`` is currently cached."""

    @abstractmethod
    async def clear(self) -> bool:
        """Drop every cached entry. Returns True on success."""

    @abstractmethod
    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Fetch several keys at once; the result contains only keys found."""

    @abstractmethod
    async def set_many(self, items: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Store several key/value pairs with a shared TTL. Returns success."""

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """Return implementation-specific cache statistics."""

    def generate_key(self, *args, **kwargs) -> str:
        """Build a stable cache key from positional and keyword arguments.

        Keyword arguments are sorted so the key is order-independent. Keys
        longer than 250 characters are md5-hashed (non-cryptographic use —
        just key shortening). The plugin name is always the prefix.
        """
        positional = [str(arg) for arg in args]
        keyed = [f"{k}:{v}" for k, v in sorted(kwargs.items())]
        raw_key = "|".join(positional + keyed)
        if len(raw_key) > 250:
            digest = hashlib.md5(raw_key.encode()).hexdigest()
            return f"{self.name}:{digest}"
        return f"{self.name}:{raw_key}"

    def serialize(self, value: Any) -> str:
        """JSON-encode ``value``; non-JSON types fall back to ``str()``."""
        return json.dumps(value, default=str)

    def deserialize(self, value: str) -> Any:
        """Decode a JSON string produced by :meth:`serialize`."""
        return json.loads(value)

    def configure(self, config: Dict[str, Any]):
        """Apply a configuration update, refreshing the cache tunables.

        Args:
            config: configuration mapping
        """
        super().configure(config)
        self.default_ttl = config.get("default_ttl", self.default_ttl)
        self.max_size = config.get("max_size", self.max_size)

195
src/plugins/cache/memory.py vendored Normal file
View File

@@ -0,0 +1,195 @@
"""
内存缓存插件
基于Python字典的简单内存缓存实现
"""
import asyncio
import threading
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from .base import CachePlugin
class MemoryCachePlugin(CachePlugin):
    """In-memory cache plugin.

    LRU cache backed by an OrderedDict with per-key expiry timestamps.
    Thread-safe via an RLock; a background asyncio task sweeps expired
    entries when an event loop is running, otherwise entries expire lazily
    on access.
    """

    def __init__(self, name: str = "memory", config: Dict[str, Any] = None):
        # Bug fix: the base __init__ signature is
        # (name, priority=0, ..., config=None); the previous positional call
        # super().__init__(name, config) sent the config dict into
        # ``priority``, so the configuration (default_ttl, max_size, ...)
        # was silently ignored.
        super().__init__(name=name, config=config)
        self._cache: OrderedDict = OrderedDict()
        self._expiry: Dict[str, float] = {}  # key -> absolute expiry (epoch seconds)
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._cleanup_task = None
        self._cleanup_interval = 60  # seconds between expiry sweeps
        if config is not None:
            self._cleanup_interval = config.get("cleanup_interval", 60)
        try:
            self._start_cleanup_task()
        except Exception:
            # No running event loop (e.g. constructed synchronously at import
            # time) — the cache still works; expiry happens lazily on access.
            pass

    def _start_cleanup_task(self):
        """Schedule the periodic expiry sweep on the running event loop."""

        async def cleanup_loop():
            while self.enabled:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()

        # get_running_loop() raises RuntimeError outside an event loop; the
        # caller treats that as "no background cleanup". (get_event_loop()
        # is deprecated for this use since Python 3.10.)
        loop = asyncio.get_running_loop()
        self._cleanup_task = loop.create_task(cleanup_loop())

    async def _cleanup_expired(self):
        """Remove every entry whose expiry timestamp has passed."""
        now = time.time()
        expired_keys = []
        with self._lock:
            for key, expiry in self._expiry.items():
                if expiry < now:
                    expired_keys.append(key)
            for key in expired_keys:
                self._cache.pop(key, None)
                self._expiry.pop(key, None)
                self._evictions += 1

    def _check_size(self):
        """Evict the least-recently-used entry when the cache is full."""
        if len(self._cache) >= self.max_size:
            # The front of the OrderedDict is the LRU entry.
            key = next(iter(self._cache))
            self._cache.pop(key)
            self._expiry.pop(key, None)
            self._evictions += 1

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None if absent/expired."""
        with self._lock:
            # Lazy expiry check.
            if key in self._expiry:
                if self._expiry[key] < time.time():
                    self._cache.pop(key, None)
                    self._expiry.pop(key)
                    self._misses += 1
                    return None
            if key in self._cache:
                # Re-insert at the MRU end to maintain LRU order.
                value = self._cache.pop(key)
                self._cache[key] = value
                self._hits += 1
                return self.deserialize(value)
            self._misses += 1
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key`` with an optional TTL in seconds."""
        with self._lock:
            # Enforce the size limit before inserting a brand-new key.
            if key not in self._cache:
                self._check_size()
            # Bug fix: always serialize, even plain strings. The previous
            # code stored raw ``str`` values unserialized and ``get`` then
            # json-decoded them, so set(k, "hello") -> get(k) raised, and
            # set(k, '{"a": 1}') came back as a dict instead of the string.
            self._cache[key] = self.serialize(value)
            self._cache.move_to_end(key)  # newest at the MRU end
            if ttl is None:
                ttl = self.default_ttl
            if ttl > 0:
                self._expiry[key] = time.time() + ttl
            return True

    async def delete(self, key: str) -> bool:
        """Remove ``key``; returns True when an entry was removed."""
        with self._lock:
            if key in self._cache:
                self._cache.pop(key)
                self._expiry.pop(key, None)
                return True
            return False

    async def exists(self, key: str) -> bool:
        """Return True when ``key`` is cached and not expired."""
        with self._lock:
            if key in self._expiry:
                if self._expiry[key] < time.time():
                    # Expired — drop it eagerly.
                    self._cache.pop(key, None)
                    self._expiry.pop(key)
                    return False
            return key in self._cache

    async def clear(self) -> bool:
        """Drop every cached entry."""
        with self._lock:
            self._cache.clear()
            self._expiry.clear()
            return True

    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Fetch several keys; the result contains only the keys found."""
        result = {}
        for key in keys:
            value = await self.get(key)
            if value is not None:
                result[key] = value
        return result

    async def set_many(self, items: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Store several key/value pairs with a shared TTL."""
        success = True
        for key, value in items.items():
            if not await self.set(key, value, ttl):
                success = False
        return success

    async def get_stats(self) -> Dict[str, Any]:
        """Return hit/miss/eviction statistics for this cache."""
        total_requests = self._hits + self._misses
        hit_rate = self._hits / total_requests if total_requests > 0 else 0
        return {
            "type": "memory",
            "size": len(self._cache),
            "max_size": self.max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "evictions": self._evictions,
            "cleanup_interval": self._cleanup_interval,
        }

    async def _do_shutdown(self):
        """Cancel the background sweep task and await its termination."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

    def __del__(self):
        """Best-effort cancellation if the plugin was never shut down."""
        if hasattr(self, "_cleanup_task") and self._cleanup_task:
            self._cleanup_task.cancel()

202
src/plugins/common.py Normal file
View File

@@ -0,0 +1,202 @@
"""
插件系统通用定义
包含所有插件类型共享的类和接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
from src.core.logger import logger
class HealthStatus(Enum):
    """Health states a plugin can report."""

    HEALTHY = "healthy"      # fully operational
    DEGRADED = "degraded"    # working, but impaired
    UNHEALTHY = "unhealthy"  # not usable
@dataclass
class PluginMetadata:
    """Plugin metadata: identity, versioning, and dependency declarations."""

    name: str  # unique plugin name
    version: str = "1.0.0"  # the plugin's own semantic version
    author: str = "Unknown"
    description: str = ""
    api_version: str = "1.0"  # plugin-API version this plugin targets
    # None is normalized to a fresh list in __post_init__ so that instances
    # never share a mutable default.
    dependencies: Optional[List[str]] = None  # names of plugins this one requires
    provides: Optional[List[str]] = None  # service names this plugin offers

    def __post_init__(self):
        if self.dependencies is None:
            self.dependencies = []
        if self.provides is None:
            self.provides = []
class BasePlugin(ABC):
"""
所有插件的基类
定义插件的基本生命周期和元数据管理
"""
def __init__(
self,
name: str,
priority: int = 0,
version: str = "1.0.0",
author: str = "Unknown",
description: str = "",
api_version: str = "1.0",
dependencies: List[str] = None,
provides: List[str] = None,
config: Dict[str, Any] = None,
):
"""
初始化插件
Args:
name: 插件名称
priority: 优先级(数字越大优先级越高)
version: 插件版本
author: 插件作者
description: 插件描述
api_version: API版本
dependencies: 依赖的其他插件
provides: 提供的服务
config: 配置字典
"""
self.name = name
self.priority = priority
self.enabled = True
self.config = config or {}
self.metadata = PluginMetadata(
name=name,
version=version,
author=author,
description=description,
api_version=api_version,
dependencies=dependencies or [],
provides=provides or [],
)
self._initialized = False
async def initialize(self) -> bool:
"""
初始化插件
Returns:
初始化成功返回True失败返回False
"""
if self._initialized:
return True
try:
await self._do_initialize()
self._initialized = True
return True
except Exception as e:
logger.warning(f"Failed to initialize plugin {self.name}: {e}")
return False
async def _do_initialize(self):
"""
子类可以重写此方法来实现特定的初始化逻辑
"""
pass
async def shutdown(self):
"""
关闭插件,清理资源
"""
if not self._initialized:
return
try:
await self._do_shutdown()
except Exception as e:
logger.warning(f"Error during plugin {self.name} shutdown: {e}")
finally:
self._initialized = False
async def _do_shutdown(self):
"""
子类可以重写此方法来实现特定的清理逻辑
"""
pass
async def health_check(self) -> HealthStatus:
"""
检查插件健康状态
Returns:
插件健康状态
"""
if not self._initialized or not self.enabled:
return HealthStatus.UNHEALTHY
try:
return await self._do_health_check()
except Exception:
return HealthStatus.UNHEALTHY
async def _do_health_check(self) -> HealthStatus:
"""
子类可以重写此方法来实现特定的健康检查逻辑
默认实现:如果插件已初始化且启用,则认为健康
"""
return (
HealthStatus.HEALTHY if (self._initialized and self.enabled) else HealthStatus.UNHEALTHY
)
def configure(self, config: Dict[str, Any]):
"""
配置插件
Args:
config: 配置字典
"""
self.config.update(config)
self.enabled = config.get("enabled", True)
def get_metadata(self) -> PluginMetadata:
"""
获取插件元数据
Returns:
插件元数据
"""
return self.metadata
@property
def is_initialized(self) -> bool:
"""检查插件是否已初始化"""
return self._initialized
def validate_dependencies(self, available_plugins: Dict[str, List[str]]) -> List[str]:
"""
验证插件依赖是否满足
Args:
available_plugins: 可用插件字典 {plugin_type: [plugin_names]}
Returns:
缺失的依赖列表
"""
missing_deps = []
for dep in self.metadata.dependencies:
found = False
for plugin_type, plugin_names in available_plugins.items():
if dep in plugin_names:
found = True
break
if not found:
missing_deps.append(dep)
return missing_deps
def __repr__(self):
return f"<{self.__class__.__name__}(name={self.name}, priority={self.priority}, enabled={self.enabled}, version={self.metadata.version})>"

View File

@@ -0,0 +1,13 @@
"""
负载均衡策略插件
"""
from .base import LoadBalancerStrategy, ProviderCandidate, SelectionResult
from .sticky_priority import StickyPriorityStrategy
__all__ = [
"LoadBalancerStrategy",
"ProviderCandidate",
"SelectionResult",
"StickyPriorityStrategy",
]

View File

@@ -0,0 +1,134 @@
"""
负载均衡策略基类
定义负载均衡策略的接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..common import BasePlugin
@dataclass
class ProviderCandidate:
    """One provider that is eligible for load-balancer selection."""

    provider: Any  # the Provider object itself
    priority: int = 0  # higher priority wins
    weight: float = 1.0  # relative selection weight within a priority tier
    model: Optional[Any] = None  # optional Model object when model info is needed
    metadata: Optional[Dict[str, Any]] = None  # free-form extra data

    def __post_init__(self):
        # Guarantee a real dict so instances never share a mutable default.
        self.metadata = {} if self.metadata is None else self.metadata
@dataclass
class SelectionResult:
    """Outcome of a load-balancer selection."""

    provider: Any  # the chosen provider
    priority: int  # priority of the chosen provider
    weight: float  # weight of the chosen provider
    selection_metadata: Optional[Dict[str, Any]] = None  # how the choice was made

    def __post_init__(self):
        # Guarantee a real dict so callers can annotate it freely.
        self.selection_metadata = (
            {} if self.selection_metadata is None else self.selection_metadata
        )
class LoadBalancerStrategy(BasePlugin):
    """Base class for load-balancing strategies.

    Concrete strategies implement ``select`` and ``get_stats``;
    ``record_result`` is an optional feedback hook.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """Initialize the load-balancing strategy.

        Args:
            name: strategy name
            priority: priority (higher wins)
            version: plugin version
            author: plugin author
            description: human-readable description
            api_version: plugin-API version targeted
            dependencies: names of plugins this one depends on
            provides: service names this plugin offers
            config: configuration mapping
        """
        strategy_kwargs = dict(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )
        super().__init__(**strategy_kwargs)

    @abstractmethod
    async def select(
        self, candidates: List[ProviderCandidate], context: Optional[Dict[str, Any]] = None
    ) -> Optional[SelectionResult]:
        """Choose one provider from the candidates.

        Args:
            candidates: eligible providers
            context: request context (request id, user info, ...)

        Returns:
            The selection result, or None when nothing is available.
        """

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """Return strategy statistics as a mapping."""

    async def record_result(
        self,
        provider: Any,
        success: bool,
        response_time: Optional[float] = None,
        error: Optional[Exception] = None,
    ):
        """Feedback hook invoked after each request.

        Adaptive strategies may override this to adjust themselves; the
        default implementation ignores the result.

        Args:
            provider: provider object
            success: whether the request succeeded
            response_time: response time in seconds
            error: the error when the request failed
        """
        return None

View File

@@ -0,0 +1,450 @@
"""
粘性优先级负载均衡策略
正常情况下始终选择同一个提供商(优先级最高+权重最大),只在故障时切换
WARNING: 多进程环境注意事项
=============================
此插件的健康状态和粘性缓存存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
每个 worker 进程有独立的状态,可能导致:
- 不同 worker 看到的提供商健康状态不同
- 粘性路由在不同 worker 间不一致
- 统计数据分散在各个 worker 中
解决方案:
1. 单 worker 模式:适用于低流量场景
2. Redis 共享状态:将 _provider_health 和 _sticky_providers 迁移到 Redis
3. 使用独立的健康检查服务:所有 worker 共享同一个健康状态源
目前项目已有 Redis 依赖,建议在高可用场景下将状态迁移到 Redis。
参考src/services/health_monitor.py 中的实现
"""
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional
from src.core.logger import logger
from .base import LoadBalancerStrategy, ProviderCandidate, SelectionResult
class StickyPriorityStrategy(LoadBalancerStrategy):
    """
    Sticky-priority strategy.

    Selection logic:
    1. Within the highest-priority group, the provider with the largest
       weight becomes the "sticky" provider.
    2. Under normal conditions the sticky provider is always chosen.
    3. Only when the sticky provider fails does traffic switch to another
       provider of the same priority.
    4. When the sticky provider recovers, traffic automatically switches back.

    Characteristics:
    - Minimizes provider switching; traffic concentrates on one provider
    - Automatic failover and recovery
    - Suited to scenarios that must concentrate usage on a single API key

    Note:
        State lives in process memory; with multiple workers each worker
        keeps independent state. See the module docstring.
    """

    def __init__(self, config: Dict[str, Any] = None):
        config = config or {}  # guard: config is dereferenced below
        super().__init__(
            name="sticky_priority",
            priority=110,  # higher than the default priority_weighted strategy
            version="1.0.0",
            author="System",
            description="粘性优先级负载均衡策略,正常时始终使用同一提供商",
            api_version="1.0",
            provides=["load_balancer"],
            config=config,
        )
        # Tunables
        self.failure_threshold = config.get("failure_threshold", 3)  # consecutive failures before unhealthy
        self.recovery_delay = config.get("recovery_delay", 30)  # seconds before a failed provider may recover
        self.enable_auto_recovery = config.get("enable_auto_recovery", True)  # allow automatic recovery
        # Per-provider health tracking {provider_id: health_info}
        self._provider_health: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {
                "consecutive_failures": 0,
                "last_failure_time": None,
                "is_healthy": True,
                "total_requests": 0,
                "total_failures": 0,
            }
        )
        # Current sticky provider per request source {cache_key: provider_id};
        # cache_key is derived from api_key_id / user_id (see _get_cache_key).
        self._sticky_providers: Dict[str, str] = {}
        # Aggregate counters
        self._stats = {
            "total_selections": 0,
            "provider_selections": {},
            "sticky_hits": 0,  # times the sticky provider was chosen
            "failovers": 0,  # times traffic switched away from the sticky provider
            "auto_recoveries": 0,  # times a provider auto-recovered
        }

    async def select(
        self, candidates: List[ProviderCandidate], context: Optional[Dict[str, Any]] = None
    ) -> Optional[SelectionResult]:
        """
        Pick one provider from the candidates.

        Args:
            candidates: candidate providers
            context: request context (contains api_key_id etc.)

        Returns:
            The selection result, or None when no candidate exists.
        """
        if not candidates:
            logger.warning("No candidates available for selection")
            return None
        if len(candidates) == 1:
            # Only one option — selection is trivial and counted as sticky.
            candidate = candidates[0]
            self._record_selection(candidate.provider, is_sticky=True)
            return SelectionResult(
                provider=candidate.provider,
                priority=candidate.priority,
                weight=candidate.weight,
                selection_metadata={"strategy": "single_candidate"},
            )
        # Identify the request source so each source keeps its own sticky entry.
        cache_key = self._get_cache_key(context)
        # Only the highest-priority group is considered.
        priority_groups = self._group_by_priority(candidates)
        highest_priority = max(priority_groups.keys())
        highest_group = priority_groups[highest_priority]
        # Resolve (or establish) the sticky provider for this source.
        sticky_candidate = self._determine_sticky_provider(highest_group, cache_key, context)
        # Check the sticky provider's health.
        provider_id = str(sticky_candidate.provider.id)
        health_info = self._provider_health[provider_id]
        # Healthy sticky provider: use it directly.
        if health_info["is_healthy"]:
            self._record_selection(sticky_candidate.provider, is_sticky=True)
            logger.info(f"Selected sticky provider {sticky_candidate.provider.name}")
            return SelectionResult(
                provider=sticky_candidate.provider,
                priority=sticky_candidate.priority,
                weight=sticky_candidate.weight,
                selection_metadata={
                    "strategy": "sticky_priority",
                    "is_sticky": True,
                    "cache_key": cache_key,
                    "health_status": "healthy",
                },
            )
        # Sticky provider is unhealthy: fail over within the same priority group.
        logger.warning(f"Sticky provider {sticky_candidate.provider.name} is unhealthy, selecting backup")
        backup_candidate = self._select_backup_provider(highest_group)
        if not backup_candidate:
            # No healthy backup either — degrade to the unhealthy sticky provider.
            logger.warning("No healthy backup provider available, falling back to unhealthy sticky provider")
            backup_candidate = sticky_candidate
        self._record_selection(backup_candidate.provider, is_sticky=False)
        self._stats["failovers"] += 1
        logger.info(f"Selected backup provider {backup_candidate.provider.name}")
        return SelectionResult(
            provider=backup_candidate.provider,
            priority=backup_candidate.priority,
            weight=backup_candidate.weight,
            selection_metadata={
                "strategy": "sticky_priority",
                "is_sticky": False,
                "is_failover": True,
                "original_provider_id": provider_id,
                "health_status": "backup",
            },
        )

    def _get_cache_key(self, context: Optional[Dict[str, Any]]) -> str:
        """
        Build the cache key identifying the request source.

        Args:
            context: request context

        Returns:
            The cache key ("default" when no identifying info is available).
        """
        if not context:
            return "default"
        # Prefer api_key_id as the identity.
        if "api_key_id" in context:
            return f"api_key_{context['api_key_id']}"
        # Fall back to user identity.
        if "user_id" in context:
            return f"user_{context['user_id']}"
        return "default"

    def _group_by_priority(
        self, candidates: List[ProviderCandidate]
    ) -> Dict[int, List[ProviderCandidate]]:
        """Group the candidates by their priority value."""
        groups: Dict[int, List[ProviderCandidate]] = {}
        for candidate in candidates:
            priority = candidate.priority
            if priority not in groups:
                groups[priority] = []
            groups[priority].append(candidate)
        return groups

    def _determine_sticky_provider(
        self,
        candidates: List[ProviderCandidate],
        cache_key: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> ProviderCandidate:
        """
        Resolve the sticky provider for this cache key.

        Strategy:
        1. If a sticky provider is cached and still among the candidates,
           keep it (auto-recovering it when eligible).
        2. Otherwise choose the largest-weight candidate as the new sticky one.

        Args:
            candidates: candidates sharing the same priority
            cache_key: cache key
            context: request context

        Returns:
            The sticky provider candidate.
        """
        # Look for a cached sticky provider first.
        if cache_key in self._sticky_providers:
            cached_provider_id = self._sticky_providers[cache_key]
            # Is it still among the candidates?
            for candidate in candidates:
                if str(candidate.provider.id) == cached_provider_id:
                    # Auto-recover it when the recovery delay has elapsed.
                    if self._can_auto_recover(cached_provider_id):
                        logger.info(f"Auto-recovering sticky provider {candidate.provider.name}")
                        self._stats["auto_recoveries"] += 1
                        # Reset its health state.
                        self._provider_health[cached_provider_id]["is_healthy"] = True
                        self._provider_health[cached_provider_id]["consecutive_failures"] = 0
                    # NOTE(review): indentation reconstructed from a flattened
                    # source — the cached candidate is returned whether or not
                    # it just recovered; select() handles the still-unhealthy
                    # case via its failover path. Confirm against the original.
                    return candidate
        # No (valid) cached entry: pick the heaviest candidate and remember it.
        sticky_candidate = max(candidates, key=lambda c: c.weight)
        self._sticky_providers[cache_key] = str(sticky_candidate.provider.id)
        logger.info(f"Set new sticky provider {sticky_candidate.provider.name}")
        return sticky_candidate

    def _can_auto_recover(self, provider_id: str) -> bool:
        """
        Whether the provider may be automatically recovered.

        Args:
            provider_id: provider id

        Returns:
            True when recovery is allowed.
        """
        if not self.enable_auto_recovery:
            return False
        health_info = self._provider_health[provider_id]
        # Already healthy — nothing to recover.
        if health_info["is_healthy"]:
            return True
        # Recover only after the configured delay has elapsed.
        if health_info["last_failure_time"]:
            time_since_failure = time.time() - health_info["last_failure_time"]
            if time_since_failure >= self.recovery_delay:
                return True
        return False

    def _select_backup_provider(
        self, candidates: List[ProviderCandidate]
    ) -> Optional[ProviderCandidate]:
        """
        Choose a healthy backup from the candidates.

        Prefers the healthy candidate with the largest weight.

        Args:
            candidates: candidate providers

        Returns:
            The backup candidate, or None when no healthy one exists.
        """
        healthy_candidates = []
        for candidate in candidates:
            provider_id = str(candidate.provider.id)
            health_info = self._provider_health[provider_id]
            # A candidate also qualifies if it can be auto-recovered right now.
            if health_info["is_healthy"] or self._can_auto_recover(provider_id):
                if not health_info["is_healthy"]:
                    # Perform the auto-recovery as a side effect.
                    health_info["is_healthy"] = True
                    health_info["consecutive_failures"] = 0
                    self._stats["auto_recoveries"] += 1
                healthy_candidates.append(candidate)
        if not healthy_candidates:
            return None
        # Heaviest healthy candidate wins.
        return max(healthy_candidates, key=lambda c: c.weight)

    def _record_selection(self, provider: Any, is_sticky: bool = True):
        """Update the selection counters for the chosen provider."""
        self._stats["total_selections"] += 1
        provider_id = str(provider.id)
        if provider_id not in self._stats["provider_selections"]:
            self._stats["provider_selections"][provider_id] = 0
        self._stats["provider_selections"][provider_id] += 1
        if is_sticky:
            self._stats["sticky_hits"] += 1

    async def record_result(
        self,
        provider: Any,
        success: bool,
        response_time: Optional[float] = None,
        error: Optional[Exception] = None,
    ):
        """
        Record a request outcome and update the provider's health state.

        Args:
            provider: provider object
            success: whether the request succeeded
            response_time: response time in seconds
            error: the error when the request failed
        """
        provider_id = str(provider.id)
        health_info = self._provider_health[provider_id]
        health_info["total_requests"] += 1
        if success:
            # Success resets the consecutive-failure counter.
            health_info["consecutive_failures"] = 0
            health_info["is_healthy"] = True
            logger.debug(f"Recorded successful result for provider {provider.name}")
        else:
            # Failure: bump counters and the last-failure timestamp.
            health_info["consecutive_failures"] += 1
            health_info["total_failures"] += 1
            health_info["last_failure_time"] = time.time()
            # Mark unhealthy once the failure threshold is reached.
            if health_info["consecutive_failures"] >= self.failure_threshold:
                health_info["is_healthy"] = False
                logger.warning(f"Provider {provider.name} marked as unhealthy")
            else:
                logger.debug(f"Recorded failed result for provider {provider.name}")

    async def get_stats(self) -> Dict[str, Any]:
        """Return selection, failover and per-provider health statistics."""
        # Aggregate health counts.
        healthy_count = sum(1 for info in self._provider_health.values() if info["is_healthy"])
        total_providers = len(self._provider_health)
        # Sticky hit rate over all selections.
        sticky_hit_rate = 0.0
        if self._stats["total_selections"] > 0:
            sticky_hit_rate = self._stats["sticky_hits"] / self._stats["total_selections"]
        return {
            "strategy": "sticky_priority",
            "total_selections": self._stats["total_selections"],
            "provider_selections": self._stats["provider_selections"],
            "sticky_hits": self._stats["sticky_hits"],
            "sticky_hit_rate": sticky_hit_rate,
            "failovers": self._stats["failovers"],
            "auto_recoveries": self._stats["auto_recoveries"],
            "healthy_providers": healthy_count,
            "total_providers": total_providers,
            "provider_health": {
                provider_id: {
                    "is_healthy": info["is_healthy"],
                    "consecutive_failures": info["consecutive_failures"],
                    "total_requests": info["total_requests"],
                    "total_failures": info["total_failures"],
                    "failure_rate": (
                        info["total_failures"] / info["total_requests"]
                        if info["total_requests"] > 0
                        else 0
                    ),
                    "last_failure_time": info["last_failure_time"],
                }
                for provider_id, info in self._provider_health.items()
            },
            "sticky_providers": self._sticky_providers,
            "config": {
                "failure_threshold": self.failure_threshold,
                "recovery_delay": self.recovery_delay,
                "enable_auto_recovery": self.enable_auto_recovery,
            },
        }

    async def reset_provider_health(self, provider_id: str):
        """Reset the tracked health state of one provider."""
        if provider_id in self._provider_health:
            self._provider_health[provider_id] = {
                "consecutive_failures": 0,
                "last_failure_time": None,
                "is_healthy": True,
                "total_requests": 0,
                "total_failures": 0,
            }
            logger.info(f"Reset health status for provider {provider_id}")

    async def clear_sticky_cache(self, cache_key: Optional[str] = None):
        """
        Clear the sticky-provider cache.

        Args:
            cache_key: the cache key to clear; None clears everything.
        """
        if cache_key:
            if cache_key in self._sticky_providers:
                del self._sticky_providers[cache_key]
                logger.info(f"Cleared sticky provider cache for key: {cache_key}")
        else:
            self._sticky_providers.clear()
            logger.info("Cleared all sticky provider cache")

579
src/plugins/manager.py Normal file
View File

@@ -0,0 +1,579 @@
"""
插件管理器
统一管理和协调所有插件系统
"""
import asyncio
import importlib
import inspect
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
from src.core.logger import logger
from src.plugins.auth.base import AuthPlugin
from src.plugins.cache.base import CachePlugin
# 移除审计插件 - 审计功能现在是核心服务,不再作为插件
from src.plugins.common import BasePlugin, HealthStatus, PluginMetadata
from src.plugins.load_balancer.base import LoadBalancerStrategy
from src.plugins.monitor.base import MonitorPlugin
from src.plugins.notification.base import NotificationPlugin
from src.plugins.rate_limit.base import RateLimitStrategy
from src.plugins.token.base import TokenCounterPlugin
class PluginManager:
    """
    Unified plugin manager.

    Responsible for discovering, loading, configuring and coordinating
    every type of plugin (auth, rate-limit, cache, monitor, token,
    notification, load-balancer).
    """

    # API version currently supported by this manager
    SUPPORTED_API_VERSION = "1.0"

    # Mapping from plugin-type name to the base class plugins must subclass
    PLUGIN_TYPES = {
        "auth": AuthPlugin,
        "rate_limit": RateLimitStrategy,
        "cache": CachePlugin,
        "monitor": MonitorPlugin,
        "token": TokenCounterPlugin,
        "notification": NotificationPlugin,
        "load_balancer": LoadBalancerStrategy,
        # "audit" removed - auditing is now a core service, not a plugin
    }

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the plugin manager.

        Args:
            config: Configuration dictionary keyed by plugin type.
        """
        self.config = config or {}
        # Registered plugin instances: {plugin_type: {plugin_name: plugin}}
        self.plugins: Dict[str, Dict[str, Any]] = {
            "auth": {},
            "rate_limit": {},
            "cache": {},
            "monitor": {},
            "token": {},
            "notification": {},
            "load_balancer": {},
            # "audit" removed - auditing is now a core service
        }
        # Name of the default plugin per type (None = first registered wins)
        self.default_plugins: Dict[str, Optional[str]] = {
            "auth": None,
            "rate_limit": None,
            "cache": None,
            "monitor": None,
            "token": None,
            "notification": None,
            "load_balancer": "sticky_priority",  # sticky-priority strategy by default
            # "audit" removed - auditing is now a core service
        }
        # Plugins skipped/disabled because their API version is incompatible
        self._incompatible_plugins: List[str] = []
        # Auto-discover and load plugins from the package directories
        self._auto_discover_plugins()
        # Apply the supplied configuration
        self._apply_config()

    def _auto_discover_plugins(self):
        """Discover and load plugins from each type's subdirectory."""
        plugins_dir = Path(__file__).parent
        for plugin_type in self.PLUGIN_TYPES:
            type_dir = plugins_dir / plugin_type
            if not type_dir.exists():
                continue
            # Scan the plugin directory; skip private modules and the base class
            for file_path in type_dir.glob("*.py"):
                if file_path.name.startswith("_") or file_path.name == "base.py":
                    continue
                module_name = f"src.plugins.{plugin_type}.{file_path.stem}"
                try:
                    module = importlib.import_module(module_name)
                    self._load_plugin_from_module(module, plugin_type)
                except Exception as e:
                    logger.error(f"Failed to load plugin module {module_name}: {e}")

    def _is_api_version_compatible(self, plugin_api_version: str) -> bool:
        """
        Check whether a plugin's declared API version is compatible.

        Uses semantic-versioning major-number compatibility:
        same major version => compatible (e.g. supported "1.0" accepts
        plugin versions "1.0", "1.1", "1.2").

        Args:
            plugin_api_version: API version declared by the plugin.

        Returns:
            True when compatible.
        """
        try:
            supported_major = self.SUPPORTED_API_VERSION.split(".")[0]
            plugin_major = plugin_api_version.split(".")[0]
            return supported_major == plugin_major
        except (ValueError, IndexError):
            # Unparseable version string: assume compatible
            return True

    def _load_plugin_from_module(self, module: Any, plugin_type: str):
        """Instantiate and register every plugin class found in *module*."""
        base_class = self.PLUGIN_TYPES[plugin_type]
        for name, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, base_class) and obj != base_class:
                # Instantiate the plugin
                try:
                    plugin_instance = obj()
                    # Check API version compatibility; incompatible plugins
                    # are still registered but disabled.
                    plugin_api_version = getattr(plugin_instance.metadata, "api_version", "1.0")
                    if not self._is_api_version_compatible(plugin_api_version):
                        logger.warning(f"Plugin {plugin_instance.name} has incompatible API version "
                                       f"{plugin_api_version} (supported: {self.SUPPORTED_API_VERSION}), "
                                       f"plugin will be disabled")
                        plugin_instance.enabled = False
                        self._incompatible_plugins.append(plugin_instance.name)
                    self.register_plugin(plugin_type, plugin_instance)
                    logger.info(f"Loaded {plugin_type} plugin: {plugin_instance.name}")
                except Exception as e:
                    logger.error(f"Failed to instantiate plugin {name}: {e}")

    def _apply_config(self):
        """Push configuration down to each registered plugin."""
        for plugin_type, plugins in self.plugins.items():
            type_config = self.config.get(plugin_type, {})
            # Set the default plugin for this type if configured
            if "default" in type_config:
                self.default_plugins[plugin_type] = type_config["default"]
            # Configure each individual plugin
            for plugin_name, plugin in plugins.items():
                plugin_config = type_config.get(plugin_name, {})
                if plugin_config:
                    plugin.configure(plugin_config)

    def register_plugin(self, plugin_type: str, plugin: Any, set_as_default: bool = False):
        """
        Register a plugin instance.

        Args:
            plugin_type: Plugin type key (must exist in PLUGIN_TYPES).
            plugin: Plugin instance.
            set_as_default: Whether to make it the default for its type.

        Raises:
            ValueError: Unknown plugin type.
            TypeError: Plugin does not subclass the type's base class.
        """
        if plugin_type not in self.plugins:
            raise ValueError(f"Unknown plugin type: {plugin_type}")
        # Validate the plugin's class against the expected base
        base_class = self.PLUGIN_TYPES[plugin_type]
        if not isinstance(plugin, base_class):
            raise TypeError(
                f"Plugin must be instance of {base_class.__name__}, " f"got {type(plugin).__name__}"
            )
        # Register the plugin
        self.plugins[plugin_type][plugin.name] = plugin
        # Becomes default when explicitly requested or when no default exists yet
        if set_as_default or not self.default_plugins[plugin_type]:
            self.default_plugins[plugin_type] = plugin.name
        logger.debug(f"Registered {plugin_type} plugin: {plugin.name}")

    def unregister_plugin(self, plugin_type: str, plugin_name: str):
        """
        Unregister a plugin.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name.
        """
        if plugin_type in self.plugins:
            if plugin_name in self.plugins[plugin_type]:
                del self.plugins[plugin_type][plugin_name]
                # Clear the default slot if this plugin held it
                if self.default_plugins[plugin_type] == plugin_name:
                    self.default_plugins[plugin_type] = None
                logger.debug(f"Unregistered {plugin_type} plugin: {plugin_name}")

    def get_plugin(self, plugin_type: str, plugin_name: Optional[str] = None) -> Optional[Any]:
        """
        Look up a plugin instance.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name; when omitted the default is returned.

        Returns:
            The plugin instance, or None when not found.
        """
        if plugin_type not in self.plugins:
            return None
        if plugin_name:
            return self.plugins[plugin_type].get(plugin_name)
        # Fall back to the configured default plugin
        default_name = self.default_plugins[plugin_type]
        if default_name:
            return self.plugins[plugin_type].get(default_name)
        # No default configured: return the first available plugin
        if self.plugins[plugin_type]:
            return next(iter(self.plugins[plugin_type].values()))
        return None

    def get_plugins_by_type(self, plugin_type: str) -> List[Any]:
        """
        Return all plugins of a given type.

        Args:
            plugin_type: Plugin type key.

        Returns:
            List of plugin instances (empty for unknown types).
        """
        if plugin_type not in self.plugins:
            return []
        return list(self.plugins[plugin_type].values())

    def get_enabled_plugins(self, plugin_type: str) -> List[Any]:
        """
        Return all *enabled* plugins of a given type.

        Args:
            plugin_type: Plugin type key.

        Returns:
            List of enabled plugin instances.
        """
        plugins = self.get_plugins_by_type(plugin_type)
        return [p for p in plugins if getattr(p, "enabled", True)]

    async def execute_plugin_chain(
        self, plugin_type: str, method_name: str, *args, **kwargs
    ) -> Any:
        """
        Execute a method across the plugin chain (ordered by priority).

        Args:
            plugin_type: Plugin type key.
            method_name: Name of the method to invoke on each plugin.
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            The first non-None result, or None when every plugin fails
            or returns None.
        """
        plugins = self.get_enabled_plugins(plugin_type)
        # Sort by priority, highest first (plugins without priority rank 0)
        plugins.sort(key=lambda p: getattr(p, "priority", 0), reverse=True)
        for plugin in plugins:
            if hasattr(plugin, method_name):
                method = getattr(plugin, method_name)
                try:
                    # Support both sync and async plugin methods
                    if asyncio.iscoroutinefunction(method):
                        result = await method(*args, **kwargs)
                    else:
                        result = method(*args, **kwargs)
                    if result is not None:
                        return result
                except Exception as e:
                    # A failing plugin does not break the chain
                    logger.error(f"Plugin {plugin.name} failed in {method_name}: {e}")
        return None

    def get_stats(self) -> Dict[str, Any]:
        """
        Collect plugin-manager statistics.

        Returns:
            Statistics dictionary with counts, defaults and per-plugin
            details.
            NOTE(review): "default_plugins" exposes the internal dict by
            reference - callers should treat it as read-only.
        """
        stats = {
            "supported_api_version": self.SUPPORTED_API_VERSION,
            "plugin_counts": {},
            "enabled_counts": {},
            "default_plugins": self.default_plugins,
            "plugin_details": {},
            "incompatible_plugins": self._incompatible_plugins,
        }
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins = self.get_plugins_by_type(plugin_type)
            enabled_plugins = self.get_enabled_plugins(plugin_type)
            stats["plugin_counts"][plugin_type] = len(all_plugins)
            stats["enabled_counts"][plugin_type] = len(enabled_plugins)
            # Per-plugin detail rows
            stats["plugin_details"][plugin_type] = [
                {
                    "name": p.name,
                    "enabled": getattr(p, "enabled", True),
                    "priority": getattr(p, "priority", 0),
                    "class": type(p).__name__,
                    "api_version": getattr(p.metadata, "api_version", "unknown"),
                    "version": getattr(p.metadata, "version", "unknown"),
                }
                for p in all_plugins
            ]
        return stats

    async def initialize_all(self) -> Dict[str, bool]:
        """
        Initialize all plugins in dependency order.

        Plugins whose initialization fails are automatically disabled so
        that an incompletely-initialized plugin is never used later.

        Returns:
            Result dictionary {plugin_name: success}.
        """
        results = {}
        # Gather all plugins and sort them by declared dependencies
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))
        # Topological sort honours inter-plugin dependencies
        sorted_plugins = self._sort_plugins_by_dependencies(all_plugins)
        # Initialize in order
        for plugin in sorted_plugins:
            try:
                # Plugins without an initialize() method are assumed ready
                if not hasattr(plugin, "initialize"):
                    logger.debug(f"Plugin {plugin.name} has no initialize() method, skipping")
                    results[f"{plugin.name}"] = True
                    continue
                success = await plugin.initialize()
                results[f"{plugin.name}"] = success
                if success:
                    logger.info(f"Successfully initialized plugin: {plugin.name}")
                else:
                    # Initialization reported failure: disable the plugin
                    plugin.enabled = False
                    logger.error(f"Failed to initialize plugin: {plugin.name}, plugin has been disabled")
            except Exception as e:
                results[f"{plugin.name}"] = False
                # Initialization raised: disable the plugin
                plugin.enabled = False
                logger.error(f"Error initializing plugin {plugin.name}: {e}, plugin has been disabled")
        return results

    async def shutdown_all(self):
        """
        Shut down all plugins.

        Dependents are shut down before their dependencies (reverse
        topological order); the shutdown calls themselves run concurrently.
        """
        # Gather all plugins and reverse the dependency order
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))
        sorted_plugins = self._sort_plugins_by_dependencies(all_plugins)
        sorted_plugins.reverse()  # shut down in reverse order
        # Shut plugins down concurrently
        shutdown_tasks = []
        for plugin in sorted_plugins:
            # Only plugins exposing a shutdown() method participate
            if hasattr(plugin, "shutdown"):
                shutdown_tasks.append(plugin.shutdown())
        if shutdown_tasks:
            try:
                await asyncio.gather(*shutdown_tasks, return_exceptions=True)
                logger.info("All plugins shut down")
            except Exception as e:
                logger.error(f"Error during plugin shutdown: {e}")

    def _sort_plugins_by_dependencies(self, plugins: List[BasePlugin]) -> List[BasePlugin]:
        """
        Topologically sort plugins by their declared dependencies.

        Args:
            plugins: Plugin list.

        Returns:
            Sorted plugin list (dependencies before dependents).

        Note:
            Plugins involved in a circular dependency are disabled and
            excluded from the result.
        """
        # Map plugin name -> plugin object
        plugin_map = {plugin.name: plugin for plugin in plugins}
        # in_degree counts each plugin's unresolved dependencies
        in_degree = {plugin.name: 0 for plugin in plugins}
        # Build the dependency graph (only dependencies that are present count)
        for plugin in plugins:
            for dep in plugin.metadata.dependencies:
                if dep in in_degree:
                    in_degree[plugin.name] += 1
        # Kahn's algorithm: start from plugins with no dependencies
        queue = [name for name, degree in in_degree.items() if degree == 0]
        result = []
        while queue:
            current = queue.pop(0)
            result.append(plugin_map[current])
            # Release plugins that depended on the one just processed
            current_plugin = plugin_map[current]
            for plugin in plugins:
                if current in plugin.metadata.dependencies:
                    in_degree[plugin.name] -= 1
                    if in_degree[plugin.name] == 0:
                        queue.append(plugin.name)
        # Any plugin not emitted is part of a dependency cycle
        if len(result) != len(plugins):
            remaining = [p for p in plugins if p not in result]
            circular_names = [p.name for p in remaining]
            logger.error(f"Circular dependency detected among plugins: {circular_names}. "
                         f"These plugins will be disabled.")
            # Disable cyclic plugins instead of loading them
            for plugin in remaining:
                plugin.enabled = False
                logger.warning(f"Plugin {plugin.name} has been disabled due to circular dependency")
            # Cyclic plugins are intentionally left out of the result
        return result

    async def health_check_all(self) -> Dict[str, HealthStatus]:
        """
        Run a health check on every plugin concurrently.

        Returns:
            Health status dictionary {plugin_name: status}; a check that
            raises is reported as UNHEALTHY.
        """
        results = {}
        # Gather all plugins
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))
        # Check health concurrently
        health_tasks = []
        for plugin in all_plugins:
            health_tasks.append(plugin.health_check())
        if health_tasks:
            health_results = await asyncio.gather(*health_tasks, return_exceptions=True)
            for plugin, result in zip(all_plugins, health_results):
                if isinstance(result, Exception):
                    results[plugin.name] = HealthStatus.UNHEALTHY
                else:
                    results[plugin.name] = result
        return results

    def validate_plugin_dependencies(self) -> Dict[str, List[str]]:
        """
        Validate the dependency declarations of every plugin.

        Returns:
            Validation results {plugin_name: [missing_dependencies]};
            plugins with all dependencies satisfied are omitted.
        """
        results = {}
        # Collect the names of all available plugins, grouped by type
        available_plugins = {}
        for plugin_type in self.PLUGIN_TYPES:
            available_plugins[plugin_type] = [p.name for p in self.get_plugins_by_type(plugin_type)]
        # Ask each plugin which of its dependencies are missing
        for plugin_type in self.PLUGIN_TYPES:
            for plugin in self.get_plugins_by_type(plugin_type):
                missing_deps = plugin.validate_dependencies(available_plugins)
                if missing_deps:
                    results[plugin.name] = missing_deps
        return results

    def reload_plugin_config(
        self, plugin_type: str, plugin_name: str, new_config: Dict[str, Any]
    ) -> bool:
        """
        Reload a single plugin's configuration.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name.
            new_config: New configuration dictionary.

        Returns:
            True when the configuration was applied successfully.
        """
        plugin = self.get_plugin(plugin_type, plugin_name)
        if not plugin:
            return False
        try:
            plugin.configure(new_config)
            logger.info(f"Reloaded config for plugin {plugin_name}: {new_config}")
            return True
        except Exception as e:
            logger.error(f"Failed to reload config for plugin {plugin_name}: {e}")
            return False
# Global plugin-manager singleton, guarded by a lock for thread-safe creation
_plugin_manager: Optional[PluginManager] = None
_plugin_manager_lock = threading.Lock()
def get_plugin_manager(config: Optional[Dict[str, Any]] = None) -> PluginManager:
    """
    Return the global plugin-manager instance, creating it on first use.

    Thread-safe via double-checked locking; *config* is only consulted
    when the singleton has not been created yet.

    Args:
        config: Configuration dictionary used for the initial construction.

    Returns:
        The shared PluginManager instance.
    """
    global _plugin_manager
    # Fast path: already created, no locking needed.
    if _plugin_manager is not None:
        return _plugin_manager
    with _plugin_manager_lock:
        # Re-check under the lock: another thread may have won the race.
        if _plugin_manager is None:
            _plugin_manager = PluginManager(config)
    return _plugin_manager
def reset_plugin_manager():
    """Reset the global plugin manager singleton (intended for tests)."""
    global _plugin_manager
    with _plugin_manager_lock:
        _plugin_manager = None

View File

@@ -0,0 +1,5 @@
"""监控插件包"""
from .base import Metric, MetricType, MonitorPlugin
__all__ = ["MonitorPlugin", "Metric", "MetricType"]

250
src/plugins/monitor/base.py Normal file
View File

@@ -0,0 +1,250 @@
"""
监控插件基类
定义监控和指标收集的接口
"""
from abc import abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional
from src.plugins.common import BasePlugin
class MetricType(Enum):
    """Kinds of metrics a monitor plugin can record."""

    COUNTER = "counter"  # monotonically increasing count
    GAUGE = "gauge"  # value that can go up or down
    HISTOGRAM = "histogram"  # distribution of observations
    SUMMARY = "summary"  # quantile summary
class Metric:
    """A single metric observation (name, value, type, labels, timestamp)."""

    def __init__(
        self,
        name: str,
        value: float,
        metric_type: MetricType,
        labels: Optional[Dict[str, str]] = None,
        timestamp: Optional[datetime] = None,
        description: Optional[str] = None,
    ):
        self.name = name
        self.value = value
        self.metric_type = metric_type
        # Optional fields fall back to an empty label set and "now" in UTC.
        self.labels = labels if labels else {}
        self.timestamp = timestamp if timestamp else datetime.now(timezone.utc)
        self.description = description
class MonitorPlugin(BasePlugin):
    """
    Base class for monitoring plugins.

    Every monitoring plugin must subclass this and implement the abstract
    recording methods; the convenience helpers (``record_request``,
    ``record_token_usage``) build on those.
    """

    def __init__(self, name: str, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the monitoring plugin.

        Args:
            name: Plugin name.
            config: Configuration dictionary.
        """
        # Parent __init__ sets up metadata, enabled flag and self.config
        super().__init__(name=name, config=config, description="Monitor Plugin", version="1.0.0")
        self.flush_interval = self.config.get("flush_interval", 60)
        self.batch_size = self.config.get("batch_size", 100)

    @abstractmethod
    async def record_metric(self, metric: Metric):
        """
        Record a single metric.

        Args:
            metric: Metric data.
        """
        pass

    @abstractmethod
    async def record_batch(self, metrics: List[Metric]):
        """
        Record a batch of metrics.

        Args:
            metrics: List of metrics.
        """
        pass

    @abstractmethod
    async def increment(self, name: str, value: float = 1, labels: Optional[Dict[str, str]] = None):
        """
        Increment a counter.

        Args:
            name: Metric name.
            value: Amount to add.
            labels: Label dictionary.
        """
        pass

    @abstractmethod
    async def gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """
        Set a gauge value.

        Args:
            name: Metric name.
            value: Gauge value.
            labels: Label dictionary.
        """
        pass

    @abstractmethod
    async def histogram(
        self,
        name: str,
        value: float,
        labels: Optional[Dict[str, str]] = None,
        buckets: Optional[List[float]] = None,
    ):
        """
        Record a histogram observation.

        Args:
            name: Metric name.
            value: Observed value.
            labels: Label dictionary.
            buckets: Bucket boundaries.
        """
        pass

    @abstractmethod
    async def timing(self, name: str, duration: float, labels: Optional[Dict[str, str]] = None):
        """
        Record a timing metric.

        Args:
            name: Metric name.
            duration: Duration in seconds.
            labels: Label dictionary.
        """
        pass

    @abstractmethod
    async def flush(self):
        """Flush buffered metrics to the backend."""
        pass

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """
        Return plugin statistics.

        Returns:
            Statistics dictionary.
        """
        pass

    def _get_running_loop(self):
        """Return the running event loop, or None when called outside one.

        FIX: the original used ``asyncio.get_event_loop()``, which is
        deprecated for this purpose and raises on Python 3.12+ when no loop
        is running; tasks scheduled on a non-running loop never executed
        anyway, so silently skipping is behavior-compatible.
        """
        import asyncio

        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            return None

    def record_request(
        self,
        method: str,
        endpoint: str,
        status_code: int,
        duration: float,
        provider: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """
        Record API-request metrics (convenience method).

        Args:
            method: HTTP method.
            endpoint: Endpoint path.
            status_code: HTTP status code.
            duration: Request duration in seconds.
            provider: Provider name.
            model: Model name.
        """
        labels = {
            "method": method,
            "endpoint": endpoint,
            "status": str(status_code),
            "status_class": f"{status_code // 100}xx",
        }
        if provider:
            labels["provider"] = provider
        if model:
            labels["model"] = model
        # Fire-and-forget: schedule the async recorders on the running loop,
        # dropping the metrics when no loop is active.
        loop = self._get_running_loop()
        if loop is None:
            return
        # Request count
        loop.create_task(self.increment("http_requests_total", labels=labels))
        # Request latency
        loop.create_task(self.histogram("http_request_duration_seconds", duration, labels=labels))
        # Error count
        if status_code >= 400:
            loop.create_task(self.increment("http_errors_total", labels=labels))

    def record_token_usage(
        self,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: Optional[float] = None,
    ):
        """
        Record token-usage metrics (convenience method).

        Args:
            provider: Provider name.
            model: Model name.
            input_tokens: Input token count.
            output_tokens: Output token count.
            cost: Cost in currency units.
        """
        labels = {"provider": provider, "model": model}
        loop = self._get_running_loop()
        if loop is None:
            return
        # Token counts
        loop.create_task(self.increment("tokens_input_total", input_tokens, labels=labels))
        loop.create_task(self.increment("tokens_output_total", output_tokens, labels=labels))
        loop.create_task(
            self.increment("tokens_total", input_tokens + output_tokens, labels=labels)
        )
        # Cost
        if cost is not None:
            loop.create_task(self.increment("usage_cost_total", cost, labels=labels))

    def configure(self, config: Dict[str, Any]):
        """
        Apply configuration to the plugin.

        Args:
            config: Configuration dictionary.
        """
        self.config.update(config)
        self.enabled = config.get("enabled", True)
        self.flush_interval = config.get("flush_interval", self.flush_interval)
        self.batch_size = config.get("batch_size", self.batch_size)

    def __repr__(self):
        return f"<{self.__class__.__name__}(name={self.name}, enabled={self.enabled})>"

View File

@@ -0,0 +1,320 @@
"""
Prometheus监控插件
支持将指标导出到Prometheus
"""
import asyncio
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional
try:
from prometheus_client import REGISTRY, Counter, Gauge, Histogram, Summary, generate_latest
PROMETHEUS_AVAILABLE = True
except ImportError:
# Prometheus client not installed, plugin will be disabled
PROMETHEUS_AVAILABLE = False
Counter = Gauge = Histogram = Summary = REGISTRY = generate_latest = None
from .base import Metric, MetricType, MonitorPlugin
from src.core.logger import logger
class PrometheusPlugin(MonitorPlugin):
    """
    Prometheus monitoring plugin.

    Exports metrics through the prometheus_client library; disables
    itself when that library is not installed.
    """

    def __init__(self, name: str = "prometheus", config: Optional[Dict[str, Any]] = None):
        super().__init__(name, config)
        # Check if prometheus_client is available
        if not PROMETHEUS_AVAILABLE:
            self.enabled = False
            logger.warning("Prometheus client not installed, plugin disabled")
            return
        # Metric registry: name -> prometheus_client metric object
        self._metrics: Dict[str, Any] = {}
        # Buffer of Metric objects awaiting flush
        self._buffer: List[Metric] = []
        self._lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None  # background flush task handle
        # Pre-register the commonly used metrics
        self._init_default_metrics()
        # Start the periodic flush task
        self._start_flush_task()

    def _init_default_metrics(self):
        """Register the default metric set."""
        # HTTP request metrics
        http_label_names = ["method", "endpoint", "status", "status_class"]
        self._metrics["http_requests_total"] = Counter(
            "http_requests_total",
            "Total HTTP requests",
            http_label_names,
        )
        self._metrics["http_request_duration_seconds"] = Histogram(
            "http_request_duration_seconds",
            "HTTP request duration in seconds",
            http_label_names,
            buckets=(0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10),
        )
        self._metrics["http_errors_total"] = Counter(
            "http_errors_total",
            "Total HTTP errors",
            http_label_names,
        )
        # Token usage metrics
        self._metrics["tokens_input_total"] = Counter(
            "tokens_input_total", "Total input tokens", ["provider", "model"]
        )
        self._metrics["tokens_output_total"] = Counter(
            "tokens_output_total", "Total output tokens", ["provider", "model"]
        )
        self._metrics["tokens_total"] = Counter(
            "tokens_total", "Total tokens", ["provider", "model"]
        )
        self._metrics["usage_cost_total"] = Counter(
            "usage_cost_total", "Total usage cost in USD", ["provider", "model"]
        )
        # System metrics
        self._metrics["active_connections"] = Gauge(
            "active_connections", "Number of active connections"
        )
        self._metrics["cache_hits_total"] = Counter(
            "cache_hits_total", "Total cache hits", ["cache_type"]
        )
        self._metrics["cache_misses_total"] = Counter(
            "cache_misses_total", "Total cache misses", ["cache_type"]
        )
        # Provider health metrics
        self._metrics["provider_health"] = Gauge(
            "provider_health", "Provider health status (1=healthy, 0=unhealthy)", ["provider"]
        )
        self._metrics["provider_latency_seconds"] = Histogram(
            "provider_latency_seconds",
            "Provider response latency in seconds",
            ["provider", "model"],
            buckets=(0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30),
        )

    def _start_flush_task(self):
        """Start the periodic buffer-flush background task."""
        async def flush_loop():
            try:
                while self.enabled:
                    await asyncio.sleep(self.flush_interval)
                    if self.enabled:  # re-check: may have been disabled during sleep
                        await self.flush()
            except asyncio.CancelledError:
                # Task cancelled during shutdown - expected
                logger.debug("Prometheus flush task cancelled")
            except Exception as e:
                logger.error(f"Prometheus flush loop error: {e}")
        # Keep the task handle so shutdown() can cancel it
        try:
            loop = asyncio.get_event_loop()
            self._flush_task = loop.create_task(flush_loop())
        except RuntimeError:
            # No running event loop yet; the task can be created later
            logger.warning("No event loop available for Prometheus flush task")

    def _get_or_create_metric(self, name: str, metric_type: MetricType, labels: Optional[List[str]] = None):
        """Return the metric registered under *name*, creating it on demand."""
        if name not in self._metrics:
            labels = labels or []
            if metric_type == MetricType.COUNTER:
                self._metrics[name] = Counter(name, f"Auto-created counter {name}", labels)
            elif metric_type == MetricType.GAUGE:
                self._metrics[name] = Gauge(name, f"Auto-created gauge {name}", labels)
            elif metric_type == MetricType.HISTOGRAM:
                self._metrics[name] = Histogram(name, f"Auto-created histogram {name}", labels)
            elif metric_type == MetricType.SUMMARY:
                self._metrics[name] = Summary(name, f"Auto-created summary {name}", labels)
        return self._metrics[name]

    async def record_metric(self, metric: Metric):
        """Buffer a single metric, flushing when the buffer is full."""
        async with self._lock:
            self._buffer.append(metric)
            # Auto-flush when the buffer reaches the batch size
            if len(self._buffer) >= self.batch_size:
                await self.flush()

    async def record_batch(self, metrics: List[Metric]):
        """Buffer a batch of metrics, flushing when the buffer is full."""
        async with self._lock:
            self._buffer.extend(metrics)
            # Auto-flush when the buffer reaches the batch size
            if len(self._buffer) >= self.batch_size:
                await self.flush()

    async def increment(self, name: str, value: float = 1, labels: Optional[Dict[str, str]] = None):
        """Increment a counter, creating it on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    # Drop labels not declared on the metric (prometheus_client
                    # raises on unknown label names)
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).inc(value)
                else:
                    metric.inc(value)
            else:
                # Register a new counter with this label set
                label_names = list(labels.keys()) if labels else []
                metric = self._get_or_create_metric(name, MetricType.COUNTER, label_names)
                if labels:
                    metric.labels(**labels).inc(value)
                else:
                    metric.inc(value)
        except Exception as e:
            # Log and continue; metric recording must never break callers
            logger.warning(f"Error recording metric {name}: {e}")

    async def gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """Set a gauge value, creating the gauge on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).set(value)
                else:
                    metric.set(value)
            else:
                # Register a new gauge with this label set
                label_names = list(labels.keys()) if labels else []
                metric = self._get_or_create_metric(name, MetricType.GAUGE, label_names)
                if labels:
                    metric.labels(**labels).set(value)
                else:
                    metric.set(value)
        except Exception as e:
            logger.warning(f"Error recording gauge {name}: {e}")

    async def histogram(
        self,
        name: str,
        value: float,
        labels: Optional[Dict[str, str]] = None,
        buckets: Optional[List[float]] = None,
    ):
        """Record a histogram observation, creating the histogram on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).observe(value)
                else:
                    metric.observe(value)
            else:
                # Register a new histogram; honour custom buckets when supplied
                label_names = list(labels.keys()) if labels else []
                if buckets:
                    metric = Histogram(
                        name, f"Auto-created histogram {name}", label_names, buckets=buckets
                    )
                else:
                    metric = self._get_or_create_metric(name, MetricType.HISTOGRAM, label_names)
                self._metrics[name] = metric
                if labels:
                    metric.labels(**labels).observe(value)
                else:
                    metric.observe(value)
        except Exception as e:
            logger.warning(f"Error recording histogram {name}: {e}")

    async def timing(self, name: str, duration: float, labels: Optional[Dict[str, str]] = None):
        """Record a timing metric (implemented as a "<name>_seconds" histogram)."""
        await self.histogram(f"{name}_seconds", duration, labels)

    async def flush(self):
        """Drain the buffer, replaying each Metric into its Prometheus collector."""
        async with self._lock:
            if not self._buffer:
                return
            # Replay buffered metrics by type
            for metric in self._buffer:
                if metric.metric_type == MetricType.COUNTER:
                    await self.increment(metric.name, metric.value, metric.labels)
                elif metric.metric_type == MetricType.GAUGE:
                    await self.gauge(metric.name, metric.value, metric.labels)
                elif metric.metric_type == MetricType.HISTOGRAM:
                    await self.histogram(metric.name, metric.value, metric.labels)
            # Clear the buffer
            self._buffer.clear()

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics."""
        return {
            "type": "prometheus",
            "metrics_count": len(self._metrics),
            "buffer_size": len(self._buffer),
            "flush_interval": self.flush_interval,
            "batch_size": self.batch_size,
        }

    def get_metrics(self) -> bytes:
        """
        Render all registered metrics in Prometheus exposition format.

        Returns:
            Prometheus text-format payload.
        """
        return generate_latest(REGISTRY)

    async def shutdown(self):
        """
        Shut down the plugin and cancel the background flush task.

        Should be called when the application is shutting down.
        """
        # Disable the plugin so the flush loop exits
        self.enabled = False
        # Cancel the background task and wait for it to finish
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
        # Flush whatever remains in the buffer
        await self.flush()
        logger.info("Prometheus plugin shutdown complete")

    async def cleanup(self):
        """Alias for shutdown() used by callers expecting a cleanup() hook."""
        await self.shutdown()

View File

@@ -0,0 +1,15 @@
"""
通知插件
"""
from .base import Notification, NotificationLevel, NotificationPlugin
from .email import EmailNotificationPlugin
from .webhook import WebhookNotificationPlugin
__all__ = [
"NotificationPlugin",
"NotificationLevel",
"Notification",
"WebhookNotificationPlugin",
"EmailNotificationPlugin",
]

View File

@@ -0,0 +1,414 @@
"""
通知插件基类
定义通知的接口和数据结构
"""
import asyncio
import json
from abc import abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional
from src.plugins.common import BasePlugin
class NotificationLevel(Enum):
    """
    Notification severity levels.

    NOTE: declaration order is load-bearing - NotificationPlugin.should_send
    ranks levels by their position in this enum, so keep the ascending
    order INFO < WARNING < ERROR < CRITICAL.
    """

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
class Notification:
    """A single notification payload with serialization helpers."""

    def __init__(
        self,
        title: str,
        message: str,
        level: NotificationLevel = NotificationLevel.INFO,
        notification_type: Optional[str] = None,
        source: Optional[str] = None,
        timestamp: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ):
        self.title = title
        self.message = message
        self.level = level
        # Optional fields fall back to sensible defaults.
        self.notification_type = notification_type if notification_type else "system"
        self.source = source if source else "aether"
        self.timestamp = timestamp if timestamp else datetime.now(timezone.utc)
        self.metadata = metadata if metadata else {}
        self.recipient = recipient
        self.tags = tags if tags else []

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the notification to a plain dictionary."""
        payload = {
            "title": self.title,
            "message": self.message,
            "level": self.level.value,
            "type": self.notification_type,
            "source": self.source,
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata,
            "recipient": self.recipient,
            "tags": self.tags,
        }
        return payload

    def to_json(self) -> str:
        """Serialize the notification to a JSON string."""
        return json.dumps(self.to_dict(), default=str)

    def format_message(self, template: Optional[str] = None) -> str:
        """Render the notification as text, optionally through a format template."""
        if not template:
            # Default rendering: "[LEVEL] title" on one line, message below.
            return f"[{self.level.value.upper()}] {self.title}\n{self.message}"
        return template.format(
            title=self.title,
            message=self.message,
            level=self.level.value,
            type=self.notification_type,
            source=self.source,
            timestamp=self.timestamp.isoformat(),
            **self.metadata,
        )
class NotificationPlugin(BasePlugin):
"""
通知插件基类
所有通知插件必须实现这个接口
提供统一的重试机制,子类只需实现 _do_send 和 _do_send_batch 方法
"""
def __init__(self, name: str = "notification", config: Dict[str, Any] = None):
# 调用父类初始化设置metadata
super().__init__(
name=name, config=config, description="Notification Plugin", version="1.0.0"
)
self.min_level = NotificationLevel[self.config.get("min_level", "INFO").upper()]
self.batch_size = self.config.get("batch_size", 10)
self.flush_interval = self.config.get("flush_interval", 60) # 秒
self.retry_count = self.config.get("retry_count", 3)
self.retry_delay = self.config.get("retry_delay", 5) # 秒
self.retry_backoff = self.config.get("retry_backoff", 2.0) # 指数退避因子
# 统计信息
self._send_attempts = 0
self._send_successes = 0
self._send_failures = 0
self._retry_total = 0
async def send(self, notification: Notification) -> bool:
"""
发送单个通知(带重试机制)
Args:
notification: 通知对象
Returns:
是否发送成功
"""
if not self.should_send(notification):
return False
self._send_attempts += 1
last_error = None
for attempt in range(self.retry_count):
try:
result = await self._do_send(notification)
if result:
self._send_successes += 1
return True
except Exception as e:
last_error = e
# 如果不是最后一次尝试,等待后重试
if attempt < self.retry_count - 1:
self._retry_total += 1
delay = self.retry_delay * (self.retry_backoff**attempt)
await asyncio.sleep(delay)
# 所有重试都失败
self._send_failures += 1
if last_error:
# 可以在这里记录日志,但不抛出异常
pass
return False
@abstractmethod
async def _do_send(self, notification: Notification) -> bool:
"""
实际发送单个通知(子类实现)
Args:
notification: 通知对象
Returns:
是否发送成功
"""
pass
async def send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
"""
批量发送通知(带重试机制)
Args:
notifications: 通知列表
Returns:
发送结果统计
"""
# 过滤应该发送的通知
to_send = [n for n in notifications if self.should_send(n)]
if not to_send:
return {"total": 0, "sent": 0, "failed": 0}
self._send_attempts += len(to_send)
last_error = None
result = None
for attempt in range(self.retry_count):
try:
result = await self._do_send_batch(to_send)
if result and result.get("sent", 0) == len(to_send):
self._send_successes += result.get("sent", 0)
return result
elif result:
# 部分成功
self._send_successes += result.get("sent", 0)
self._send_failures += result.get("failed", 0)
return result
except Exception as e:
last_error = e
# 如果不是最后一次尝试,等待后重试
if attempt < self.retry_count - 1:
self._retry_total += 1
delay = self.retry_delay * (self.retry_backoff**attempt)
await asyncio.sleep(delay)
# 所有重试都失败
self._send_failures += len(to_send)
return {
"total": len(to_send),
"sent": 0,
"failed": len(to_send),
"error": str(last_error) if last_error else "Unknown error",
}
@abstractmethod
async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
"""
实际批量发送通知(子类实现)
Args:
notifications: 通知列表
Returns:
发送结果统计 {"total": int, "sent": int, "failed": int}
"""
pass
def should_send(self, notification: Notification) -> bool:
"""判断是否应该发送通知"""
if not self.enabled:
return False
# 级别过滤
level_values = {level: i for i, level in enumerate(NotificationLevel)}
if level_values[notification.level] < level_values[self.min_level]:
return False
return True
async def send_error(
self,
error: Exception,
context: Optional[Dict[str, Any]] = None,
recipient: Optional[str] = None,
) -> bool:
"""发送错误通知"""
notification = Notification(
title=f"Error: {type(error).__name__}",
message=str(error),
level=NotificationLevel.ERROR,
notification_type="error",
metadata=context or {},
recipient=recipient,
tags=["error", type(error).__name__],
)
return await self.send(notification)
async def send_warning(
self,
title: str,
message: str,
context: Optional[Dict[str, Any]] = None,
recipient: Optional[str] = None,
) -> bool:
"""发送警告通知"""
notification = Notification(
title=title,
message=message,
level=NotificationLevel.WARNING,
notification_type="warning",
metadata=context or {},
recipient=recipient,
tags=["warning"],
)
return await self.send(notification)
async def send_info(
self,
title: str,
message: str,
context: Optional[Dict[str, Any]] = None,
recipient: Optional[str] = None,
) -> bool:
"""发送信息通知"""
notification = Notification(
title=title,
message=message,
level=NotificationLevel.INFO,
notification_type="info",
metadata=context or {},
recipient=recipient,
tags=["info"],
)
return await self.send(notification)
async def send_critical(
self,
title: str,
message: str,
context: Optional[Dict[str, Any]] = None,
recipient: Optional[str] = None,
) -> bool:
"""发送严重通知"""
notification = Notification(
title=title,
message=message,
level=NotificationLevel.CRITICAL,
notification_type="critical",
metadata=context or {},
recipient=recipient,
tags=["critical"],
)
return await self.send(notification)
async def send_usage_alert(
self,
user_id: str,
usage_percent: float,
limit: int,
current: int,
resource_type: str = "tokens",
) -> bool:
"""发送使用量警告"""
level = NotificationLevel.INFO
if usage_percent >= 90:
level = NotificationLevel.CRITICAL
elif usage_percent >= 75:
level = NotificationLevel.WARNING
notification = Notification(
title=f"Usage Alert: {resource_type.capitalize()}",
message=f"User {user_id} has used {usage_percent:.1f}% of their {resource_type} quota ({current}/{limit})",
level=level,
notification_type="usage_alert",
metadata={
"user_id": user_id,
"usage_percent": usage_percent,
"limit": limit,
"current": current,
"resource_type": resource_type,
},
tags=["usage", resource_type],
)
return await self.send(notification)
async def send_provider_status(
self,
provider: str,
status: str,
error: Optional[str] = None,
latency: Optional[float] = None,
) -> bool:
"""发送提供商状态通知"""
level = NotificationLevel.INFO
if status == "down":
level = NotificationLevel.CRITICAL
elif status == "degraded":
level = NotificationLevel.WARNING
message = f"Provider {provider} is {status}"
if error:
message += f": {error}"
if latency:
message += f" (latency: {latency:.2f}s)"
notification = Notification(
title=f"Provider Status: {provider}",
message=message,
level=level,
notification_type="provider_status",
metadata={"provider": provider, "status": status, "error": error, "latency": latency},
tags=["provider", status],
)
return await self.send(notification)
async def get_stats(self) -> Dict[str, Any]:
    """Collect plugin statistics.

    Returns:
        A dict with the base send/retry counters and configuration, plus
        any subclass-specific stats merged in from ``_get_extra_stats``.
    """
    attempts = self._send_attempts
    success_rate = self._send_successes / attempts * 100 if attempts > 0 else 0
    stats: Dict[str, Any] = {
        "plugin_name": self.name,
        "enabled": self.enabled,
        "send_attempts": attempts,
        "send_successes": self._send_successes,
        "send_failures": self._send_failures,
        "retry_total": self._retry_total,
        "success_rate": success_rate,
        "config": {
            "min_level": self.min_level.value,
            "retry_count": self.retry_count,
            "retry_delay": self.retry_delay,
            "retry_backoff": self.retry_backoff,
            "batch_size": self.batch_size,
            "flush_interval": self.flush_interval,
        },
    }
    # Merge subclass-specific stats on top (may be empty).
    stats.update(await self._get_extra_stats() or {})
    return stats
async def _get_extra_stats(self) -> Dict[str, Any]:
    """Hook for subclasses to contribute plugin-specific statistics.

    Returns:
        Extra stats to merge into the base stats; empty by default.
    """
    return {}

View File

@@ -0,0 +1,374 @@
"""
邮件通知插件
通过SMTP发送邮件通知
"""
import asyncio
import json
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any, Dict, List, Optional
try:
import aiosmtplib
AIOSMTPLIB_AVAILABLE = True
except ImportError:
AIOSMTPLIB_AVAILABLE = False
aiosmtplib = None
from src.core.logger import logger
from .base import Notification, NotificationLevel, NotificationPlugin
class EmailNotificationPlugin(NotificationPlugin):
    """
    Email notification plugin.

    Delivers notifications over SMTP with HTML or plain-text bodies.
    Notifications are buffered and sent in batches; a CRITICAL notification
    (or a full buffer) triggers an immediate flush.
    """

    def __init__(self, name: str = "email", config: Dict[str, Any] = None):
        """
        Args:
            name: Plugin name.
            config: SMTP/recipient configuration. When required keys are
                missing the plugin disables itself instead of raising.
        """
        super().__init__(name, config)
        cfg = config or {}
        # SMTP connection settings
        self.smtp_host = cfg.get("smtp_host")
        self.smtp_port = cfg.get("smtp_port", 587)
        self.smtp_user = cfg.get("smtp_user")
        self.smtp_password = cfg.get("smtp_password")
        self.use_tls = cfg.get("use_tls", True)
        self.use_ssl = cfg.get("use_ssl", False)
        # Message settings
        self.from_email = cfg.get("from_email")
        self.from_name = cfg.get("from_name", "Aether")
        self.to_emails = cfg.get("to_emails", [])
        self.cc_emails = cfg.get("cc_emails", [])
        self.bcc_emails = cfg.get("bcc_emails", [])
        # Template settings
        self.use_html = cfg.get("use_html", True)
        self.subject_prefix = cfg.get("subject_prefix", "[Aether]")
        # Buffering state
        self._buffer: List[Notification] = []
        self._lock = asyncio.Lock()
        self._flush_task = None
        # Validate required configuration; disable the plugin when incomplete.
        config_errors = []
        if not self.smtp_host:
            config_errors.append("缺少 smtp_host")
        if not self.from_email:
            config_errors.append("缺少 from_email")
        if not self.to_emails:
            config_errors.append("缺少 to_emails")
        if config_errors:
            self.enabled = False
            for error in config_errors:
                logger.warning(f"Email 插件配置错误: {error},插件已禁用")
            return
        # NOTE: the flush task is not started here because there may be no
        # running event loop yet; initialize() starts it after startup.

    async def initialize(self) -> bool:
        """
        Initialize the plugin (call once an event loop is running).

        Starts the background flush task.

        Returns:
            True on success, False when the plugin is disabled.
        """
        if not self.enabled:
            # Configuration was invalid; the plugin stays disabled.
            return False
        if self._flush_task is None:
            self._start_flush_task()
        return True

    def _start_flush_task(self):
        """Start the periodic buffer-flush task."""

        async def flush_loop():
            while self.enabled:
                await asyncio.sleep(self.flush_interval)
                try:
                    await self.flush()
                except Exception as e:
                    # Keep the loop alive on transient SMTP failures.
                    logger.error(f"Email flush loop error: {e}")

        try:
            # create_task requires a running event loop.
            loop = asyncio.get_running_loop()
            self._flush_task = loop.create_task(flush_loop())
        except RuntimeError:
            # No running loop yet; the task will be created in initialize().
            logger.warning("Email 插件刷新任务等待事件循环创建")

    def _format_html_email(self, notifications: List[Notification]) -> str:
        """Render a batch of notifications as an HTML email body."""
        # Severity colors are carried by the CSS classes in the stylesheet.
        html = """
<html>
<head>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; }
.notification { margin: 20px 0; padding: 15px; border-left: 5px solid; }
.info { border-left-color: #28a745; background-color: #d4edda; }
.warning { border-left-color: #ffc107; background-color: #fff3cd; }
.error { border-left-color: #dc3545; background-color: #f8d7da; }
.critical { border-left-color: #721c24; background-color: #f8d7da; }
.title { font-weight: bold; font-size: 1.2em; margin-bottom: 10px; }
.metadata { margin-top: 10px; font-size: 0.9em; color: #666; }
.timestamp { font-size: 0.8em; color: #999; }
</style>
</head>
<body>
<h2>Notifications from Aether</h2>
"""
        for notification in notifications:
            # CSS class name matches the level value ("info", "warning", ...).
            level_class = notification.level.value
            html += f"""
<div class="notification {level_class}">
<div class="title">{notification.title}</div>
<div class="message">{notification.message}</div>
"""
            if notification.metadata:
                html += '<div class="metadata">'
                for key, value in notification.metadata.items():
                    html += f"<strong>{key}:</strong> {value}<br>"
                html += "</div>"
            html += f"""
<div class="timestamp">{notification.timestamp.isoformat()}</div>
</div>
"""
        html += """
</body>
</html>
"""
        return html

    def _format_text_email(self, notifications: List[Notification]) -> str:
        """Render a batch of notifications as a plain-text email body."""
        lines = ["Notifications from Aether", "=" * 50, ""]
        for notification in notifications:
            lines.append(f"[{notification.level.value.upper()}] {notification.title}")
            lines.append("-" * 40)
            lines.append(notification.message)
            if notification.metadata:
                lines.append("")
                for key, value in notification.metadata.items():
                    lines.append(f"  {key}: {value}")
            lines.append(f"\nTime: {notification.timestamp.isoformat()}")
            lines.append("=" * 50)
            lines.append("")
        return "\n".join(lines)

    def _build_message(self, subject: str, body: str, is_html: bool) -> MIMEMultipart:
        """Assemble the MIME message shared by the sync and async senders."""
        message = MIMEMultipart("alternative")
        message["Subject"] = f"{self.subject_prefix} {subject}"
        message["From"] = f"{self.from_name} <{self.from_email}>"
        message["To"] = ", ".join(self.to_emails)
        if self.cc_emails:
            message["Cc"] = ", ".join(self.cc_emails)
        message.attach(MIMEText(body, "html" if is_html else "plain"))
        return message

    async def _send_email_async(self, subject: str, body: str, is_html: bool = True) -> bool:
        """Send one email, preferring the async SMTP client when available."""
        if not AIOSMTPLIB_AVAILABLE:
            # Fall back to blocking smtplib in a worker thread so the event
            # loop is not blocked by network I/O.
            return await asyncio.get_running_loop().run_in_executor(
                None, self._send_email_sync, subject, body, is_html
            )
        message = self._build_message(subject, body, is_html)
        try:
            if self.use_ssl:
                # Implicit TLS from connection start (e.g. port 465).
                await aiosmtplib.send(
                    message,
                    hostname=self.smtp_host,
                    port=self.smtp_port,
                    use_tls=True,
                    username=self.smtp_user,
                    password=self.smtp_password,
                )
            else:
                # Plain connection, optionally upgraded via STARTTLS.
                await aiosmtplib.send(
                    message,
                    hostname=self.smtp_host,
                    port=self.smtp_port,
                    start_tls=self.use_tls,
                    username=self.smtp_user,
                    password=self.smtp_password,
                )
            return True
        except Exception as e:
            logger.error(f"异步邮件发送失败: {e}")
            return False

    def _send_email_sync(self, subject: str, body: str, is_html: bool = True) -> bool:
        """Send one email with blocking smtplib (fallback path)."""
        message = self._build_message(subject, body, is_html)
        try:
            if self.use_ssl:
                server = smtplib.SMTP_SSL(self.smtp_host, self.smtp_port)
            else:
                server = smtplib.SMTP(self.smtp_host, self.smtp_port)
            # Context manager guarantees QUIT/close even when sending fails
            # (the old code leaked the connection on any exception).
            with server:
                if not self.use_ssl and self.use_tls:
                    server.starttls()
                if self.smtp_user and self.smtp_password:
                    server.login(self.smtp_user, self.smtp_password)
                # BCC recipients go in the envelope only, never in headers.
                all_recipients = self.to_emails + self.cc_emails + self.bcc_emails
                server.send_message(message, to_addrs=all_recipients)
            return True
        except Exception as e:
            logger.error(f"同步邮件发送失败: {e}")
            return False

    async def _do_send(self, notification: Notification) -> bool:
        """
        Queue a single notification.

        CRITICAL notifications (and a full buffer) flush immediately;
        everything else waits for the periodic flush.
        """
        async with self._lock:
            self._buffer.append(notification)
            if notification.level == NotificationLevel.CRITICAL:
                return await self._flush_buffer()
            if len(self._buffer) >= self.batch_size:
                return await self._flush_buffer()
            return True

    async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
        """Send a batch of notifications as a single email."""
        if not notifications:
            return {"total": 0, "sent": 0, "failed": 0}
        subject = f"Batch Notifications ({len(notifications)} items)"
        # Escalate the subject when the batch contains critical items.
        critical_count = sum(1 for n in notifications if n.level == NotificationLevel.CRITICAL)
        if critical_count > 0:
            subject = f"[CRITICAL] {subject}"
        if self.use_html:
            body = self._format_html_email(notifications)
        else:
            body = self._format_text_email(notifications)
        success = await self._send_email_async(subject, body, self.use_html)
        return {
            "total": len(notifications),
            "sent": len(notifications) if success else 0,
            "failed": 0 if success else len(notifications),
        }

    async def _flush_buffer(self) -> bool:
        """Flush buffered notifications (caller must hold self._lock)."""
        if not self._buffer:
            return True
        notifications = self._buffer[:]
        self._buffer.clear()
        # Call _do_send_batch directly to avoid double-counting statistics.
        result = await self._do_send_batch(notifications)
        return result["failed"] == 0

    async def flush(self) -> bool:
        """Flush buffered notifications under the lock."""
        async with self._lock:
            return await self._flush_buffer()

    async def _get_extra_stats(self) -> Dict[str, Any]:
        """Email-specific statistics."""
        return {
            "type": "email",
            "smtp_host": self.smtp_host,
            "smtp_port": self.smtp_port,
            "from_email": self.from_email,
            "recipients_count": len(self.to_emails),
            "buffer_size": len(self._buffer),
            "use_html": self.use_html,
            "aiosmtplib_available": AIOSMTPLIB_AVAILABLE,
        }

    async def close(self):
        """Flush remaining notifications and stop the background task."""
        await self.flush()
        if self._flush_task:
            self._flush_task.cancel()

    def __del__(self):
        """Best-effort async cleanup; only possible while a loop is running."""
        try:
            # get_running_loop() raises first, so the close() coroutine is
            # never created (and never left un-awaited) without a loop.
            asyncio.get_running_loop().create_task(self.close())
        except Exception:
            # No running loop (e.g. interpreter shutdown) - nothing to do.
            pass

View File

@@ -0,0 +1,309 @@
"""
Webhook通知插件
通过HTTP Webhook发送通知
"""
import asyncio
import hashlib
import hmac
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
try:
import aiohttp
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
aiohttp = None
from src.core.logger import logger
from .base import Notification, NotificationLevel, NotificationPlugin
class WebhookNotificationPlugin(NotificationPlugin):
    """
    Webhook notification plugin.

    Posts notifications to an HTTP webhook, formatting the payload for
    Slack, Discord, Microsoft Teams, or a generic JSON consumer.
    Non-critical notifications are buffered and sent in batches.
    """

    def __init__(self, name: str = "webhook", config: Dict[str, Any] = None):
        """
        Args:
            name: Plugin name.
            config: Webhook configuration (url, type, secret, timeout, headers).
        """
        super().__init__(name, config)
        if not AIOHTTP_AVAILABLE:
            self.enabled = False
            logger.warning("aiohttp not installed, webhook plugin disabled")
            return
        cfg = config or {}
        # Webhook settings
        self.webhook_url = cfg.get("webhook_url")
        self.webhook_type = cfg.get("webhook_type", "generic")  # generic, slack, discord, teams
        self.secret = cfg.get("secret")  # used for HMAC request signing
        self.timeout = cfg.get("timeout", 30)
        self.headers = cfg.get("headers", {})
        # Buffering state
        self._buffer: List[Notification] = []
        self._lock = asyncio.Lock()
        self._session: Optional[aiohttp.ClientSession] = None
        self._flush_task = None
        if not self.webhook_url:
            self.enabled = False
            logger.warning("No webhook URL configured")
            return
        # Start the periodic flush task (deferred when no loop runs yet).
        self._start_flush_task()

    async def initialize(self) -> bool:
        """
        Initialize the plugin once an event loop is running.

        Creates the periodic flush task if construction happened before the
        loop started.

        Returns:
            True on success, False when the plugin is disabled.
        """
        if not self.enabled:
            return False
        if self._flush_task is None:
            self._start_flush_task()
        return True

    def _start_flush_task(self):
        """Start the periodic buffer-flush task."""

        async def flush_loop():
            while self.enabled:
                await asyncio.sleep(self.flush_interval)
                try:
                    await self.flush()
                except Exception as e:
                    # Keep the loop alive on transient delivery failures.
                    logger.error(f"Webhook flush loop error: {e}")

        try:
            # create_task requires a running loop; the old unguarded call
            # raised RuntimeError when the plugin was constructed at import
            # time (the email plugin already guards against this).
            self._flush_task = asyncio.get_running_loop().create_task(flush_loop())
        except RuntimeError:
            logger.warning("Webhook flush task deferred: no running event loop")

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, recreating it if it was closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self._session

    def _generate_signature(self, payload: str) -> str:
        """Return the HMAC-SHA256 hex signature of the payload, or ""."""
        if not self.secret:
            return ""
        return hmac.new(
            self.secret.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256
        ).hexdigest()

    def _format_for_slack(self, notification: Notification) -> Dict[str, Any]:
        """Format a notification as a Slack attachment payload."""
        color_map = {
            NotificationLevel.INFO: "#36a64f",
            NotificationLevel.WARNING: "warning",
            NotificationLevel.ERROR: "danger",
            NotificationLevel.CRITICAL: "#ff0000",
        }
        return {
            "text": notification.title,
            "attachments": [
                {
                    "color": color_map.get(notification.level, "#808080"),
                    "title": notification.title,
                    "text": notification.message,
                    "fields": (
                        [
                            {"title": k, "value": str(v), "short": True}
                            for k, v in notification.metadata.items()
                        ]
                        if notification.metadata
                        else []
                    ),
                    "footer": notification.source,
                    "ts": int(notification.timestamp.timestamp()),
                }
            ],
        }

    def _format_for_discord(self, notification: Notification) -> Dict[str, Any]:
        """Format a notification as a Discord embed payload."""
        color_map = {
            NotificationLevel.INFO: 0x00FF00,
            NotificationLevel.WARNING: 0xFFA500,
            NotificationLevel.ERROR: 0xFF0000,
            NotificationLevel.CRITICAL: 0x8B0000,
        }
        embeds = [
            {
                "title": notification.title,
                "description": notification.message,
                "color": color_map.get(notification.level, 0x808080),
                "fields": (
                    [
                        {"name": k, "value": str(v), "inline": True}
                        for k, v in notification.metadata.items()
                    ]
                    if notification.metadata
                    else []
                ),
                "footer": {"text": notification.source},
                "timestamp": notification.timestamp.isoformat(),
            }
        ]
        return {"embeds": embeds}

    def _format_for_teams(self, notification: Notification) -> Dict[str, Any]:
        """Format a notification as a Microsoft Teams MessageCard."""
        color_map = {
            NotificationLevel.INFO: "00ff00",
            NotificationLevel.WARNING: "ffa500",
            NotificationLevel.ERROR: "ff0000",
            NotificationLevel.CRITICAL: "8b0000",
        }
        facts = (
            [{"name": k, "value": str(v)} for k, v in notification.metadata.items()]
            if notification.metadata
            else []
        )
        return {
            "@type": "MessageCard",
            "@context": "https://schema.org/extensions",
            "themeColor": color_map.get(notification.level, "808080"),
            "title": notification.title,
            "text": notification.message,
            "sections": [{"facts": facts}] if facts else [],
            "summary": notification.title,
        }

    def _format_payload(self, notification: Notification) -> Dict[str, Any]:
        """Dispatch to the formatter matching the configured webhook type."""
        if self.webhook_type == "slack":
            return self._format_for_slack(notification)
        elif self.webhook_type == "discord":
            return self._format_for_discord(notification)
        elif self.webhook_type == "teams":
            return self._format_for_teams(notification)
        else:
            # Generic consumers receive the raw notification dict.
            return notification.to_dict()

    async def _do_send(self, notification: Notification) -> bool:
        """
        Queue a single notification.

        CRITICAL notifications (and a full buffer) flush immediately;
        everything else waits for the periodic flush.
        """
        async with self._lock:
            self._buffer.append(notification)
            if notification.level == NotificationLevel.CRITICAL:
                return await self._flush_buffer()
            if len(self._buffer) >= self.batch_size:
                return await self._flush_buffer()
            return True

    async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
        """POST each notification to the webhook and tally the results."""
        success_count = 0
        failed_count = 0
        errors = []
        if not notifications:
            return {"total": 0, "sent": 0, "failed": 0}
        for notification in notifications:
            try:
                payload = self._format_payload(notification)
                payload_str = json.dumps(payload)
                headers = dict(self.headers)
                headers["Content-Type"] = "application/json"
                # Sign the payload when a secret is configured.
                if self.secret:
                    signature = self._generate_signature(payload_str)
                    headers["X-Signature"] = signature
                    headers["X-Timestamp"] = str(int(time.time()))
                session = await self._get_session()
                async with session.post(
                    self.webhook_url, data=payload_str, headers=headers
                ) as response:
                    # Any 2xx (or 1xx) status counts as delivered.
                    if response.status < 300:
                        success_count += 1
                    else:
                        failed_count += 1
                        error_text = await response.text()
                        errors.append(f"HTTP {response.status}: {error_text}")
            except Exception as e:
                failed_count += 1
                errors.append(str(e))
        return {
            "total": len(notifications),
            "sent": success_count,
            "failed": failed_count,
            "errors": errors,
        }

    async def _flush_buffer(self) -> bool:
        """Flush buffered notifications (caller must hold self._lock)."""
        if not self._buffer:
            return True
        notifications = self._buffer[:]
        self._buffer.clear()
        # Call _do_send_batch directly to avoid double-counting statistics.
        result = await self._do_send_batch(notifications)
        return result["failed"] == 0

    async def flush(self) -> bool:
        """Flush buffered notifications under the lock."""
        async with self._lock:
            return await self._flush_buffer()

    async def _get_extra_stats(self) -> Dict[str, Any]:
        """Webhook-specific statistics."""
        return {
            "type": "webhook",
            "webhook_type": self.webhook_type,
            "webhook_url": (
                self.webhook_url.split("?")[0] if self.webhook_url else None
            ),  # strip query params (may contain tokens)
            "buffer_size": len(self._buffer),
            "has_secret": bool(self.secret),
        }

    async def _do_shutdown(self):
        """Shutdown hook: delegate to close()."""
        await self.close()

    async def close(self):
        """Flush the buffer, stop the flush task, and close the HTTP session."""
        await self.flush()
        if self._flush_task:
            self._flush_task.cancel()
        if self._session:
            await self._session.close()

    def __del__(self):
        """Best-effort async cleanup; only possible while a loop is running."""
        try:
            # get_running_loop() raises first, so the close() coroutine is
            # never created (and never left un-awaited) without a loop.
            asyncio.get_running_loop().create_task(self.close())
        except Exception:
            # No running loop; nothing can be cleaned up safely.
            pass

View File

@@ -0,0 +1,9 @@
"""
速率限制插件模块
"""
from .base import RateLimitResult, RateLimitStrategy
from .sliding_window import SlidingWindowStrategy
from .token_bucket import TokenBucketStrategy
__all__ = ["RateLimitStrategy", "RateLimitResult", "TokenBucketStrategy", "SlidingWindowStrategy"]

View File

@@ -0,0 +1,132 @@
"""
速率限制策略基类
定义速率限制策略的接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..common import BasePlugin, HealthStatus, PluginMetadata
@dataclass
class RateLimitResult:
    """
    Result of a rate-limit check.

    Attributes:
        allowed: Whether the request may proceed.
        remaining: Remaining quota in the current window/bucket.
        reset_at: When the quota fully resets.
        retry_after: Seconds the caller should wait before retrying.
        message: Optional human-readable explanation.
        headers: HTTP headers derived from the fields above.
    """

    allowed: bool
    remaining: int
    reset_at: Optional[datetime] = None
    retry_after: Optional[int] = None
    message: Optional[str] = None
    headers: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # Derive standard rate-limit HTTP headers from the result fields,
        # preserving any headers the caller already supplied.
        headers = self.headers if self.headers is not None else {}
        if self.remaining is not None:
            headers["X-RateLimit-Remaining"] = str(self.remaining)
        if self.reset_at:
            headers["X-RateLimit-Reset"] = str(int(self.reset_at.timestamp()))
        if self.retry_after:
            headers["Retry-After"] = str(self.retry_after)
        self.headers = headers
class RateLimitStrategy(BasePlugin):
    """
    Base class for rate-limit strategies.

    All rate-limit strategies must inherit from this class and implement
    the four abstract coroutines below.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: Optional[List[str]] = None,
        provides: Optional[List[str]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the rate-limit strategy.

        Args:
            name: Strategy name
            priority: Priority (higher number wins)
            version: Plugin version
            author: Plugin author
            description: Plugin description
            api_version: API version
            dependencies: Names of other plugins this one depends on
            provides: Services this plugin provides
            config: Configuration dictionary
        """
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )

    @abstractmethod
    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """
        Check the rate limit for a key.

        Args:
            key: Limit key (e.g. user ID, API key ID)
            **kwargs: Extra parameters

        Returns:
            The rate-limit check result
        """
        pass

    @abstractmethod
    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """
        Consume quota.

        Args:
            key: Limit key
            amount: Amount to consume
            **kwargs: Extra parameters

        Returns:
            Whether consumption succeeded
        """
        pass

    @abstractmethod
    async def reset(self, key: str):
        """
        Reset the limit state for a key.

        Args:
            key: Limit key
        """
        pass

    @abstractmethod
    async def get_stats(self, key: str) -> Dict[str, Any]:
        """
        Get statistics for a key.

        Args:
            key: Limit key

        Returns:
            Statistics dictionary
        """
        pass

View File

@@ -0,0 +1,363 @@
"""
滑动窗口算法速率限制策略
精确的速率限制,不允许突发
WARNING: 多进程环境注意事项
=============================
此插件的窗口状态存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
每个 worker 进程有独立的限流状态,可能导致:
- 实际允许的请求数 = 配置限制 * worker数量
- 限流效果大打折扣
解决方案:
1. 单 worker 模式:适用于低流量场景
2. Redis 共享状态:使用 Redis 实现分布式滑动窗口
3. 使用 token_bucket.py令牌桶策略可以更容易迁移到 Redis
目前项目已有 Redis 依赖src/clients/redis_client.py
建议在生产环境使用 Redis 实现分布式限流。
"""
import asyncio
import time
from collections import deque
from datetime import datetime
from typing import Any, Deque, Dict
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class SlidingWindow:
    """A fixed-length sliding time window of request timestamps."""

    def __init__(self, window_size: int, max_requests: int):
        """
        Args:
            window_size: Window length in seconds.
            max_requests: Maximum number of requests allowed per window.
        """
        self.window_size = window_size
        self.max_requests = max_requests
        self.requests: Deque[float] = deque()
        self.last_access_time: float = time.time()

    def _cleanup(self):
        """Drop request timestamps that have slid out of the window."""
        now = time.time()
        # Record the access so idle windows can be expired by the strategy.
        self.last_access_time = now
        threshold = now - self.window_size
        pending = self.requests
        while pending and pending[0] < threshold:
            pending.popleft()

    def can_accept(self, amount: int = 1) -> bool:
        """Return True when `amount` more requests fit in the window.

        Args:
            amount: Number of requests to test for.
        """
        self._cleanup()
        return len(self.requests) + amount <= self.max_requests

    def add_request(self, amount: int = 1) -> bool:
        """Record `amount` requests; return False when they do not fit.

        Args:
            amount: Number of requests to record.
        """
        if not self.can_accept(amount):
            return False
        # All requests in one call share the same timestamp.
        self.requests.extend([time.time()] * amount)
        return True

    def get_remaining(self) -> int:
        """Return how many more requests the window can currently take."""
        self._cleanup()
        return max(0, self.max_requests - len(self.requests))

    def get_reset_time(self) -> datetime:
        """Return when the oldest recorded request leaves the window."""
        self._cleanup()
        if not self.requests:
            return datetime.now()
        # The oldest entry expires window_size seconds after it was recorded.
        return datetime.fromtimestamp(self.requests[0] + self.window_size)
class SlidingWindowStrategy(RateLimitStrategy):
    """
    Sliding-window rate-limit strategy.

    Characteristics:
    - Precise rate limiting
    - No burst traffic allowed
    - Suits scenarios that need strict rate control
    - Automatically cleans up long-idle windows to prevent memory leaks
    """

    # Default maximum number of cached windows
    DEFAULT_MAX_WINDOWS = 10000
    # Default window expiry (seconds) - windows idle longer than this are removed
    DEFAULT_WINDOW_EXPIRY = 3600  # 1 hour

    def __init__(self):
        super().__init__("sliding_window")
        self.windows: Dict[str, SlidingWindow] = {}
        self._lock = asyncio.Lock()
        # Default window parameters
        self.default_window_size = 60  # 60-second window by default
        self.default_max_requests = 100  # 100 requests by default
        # Memory-management settings
        self.max_windows = self.DEFAULT_MAX_WINDOWS
        self.window_expiry = self.DEFAULT_WINDOW_EXPIRY
        self._last_cleanup_time: float = time.time()
        self._cleanup_interval = 300  # check for cleanup every 5 minutes

    def _cleanup_expired_windows(self) -> int:
        """
        Remove expired windows to prevent memory leaks.

        Returns:
            Number of windows removed
        """
        current_time = time.time()
        expired_keys = []
        for key, window in self.windows.items():
            # A window is expired when it has not been accessed for too long.
            if current_time - window.last_access_time > self.window_expiry:
                expired_keys.append(key)
        # Drop the expired windows
        for key in expired_keys:
            del self.windows[key]
        if expired_keys:
            logger.info(f"清理了 {len(expired_keys)} 个过期的滑动窗口")
        return len(expired_keys)

    def _evict_lru_windows(self, count: int) -> int:
        """
        Evict the least-recently-used windows.

        Args:
            count: Number of windows to evict

        Returns:
            Number actually evicted
        """
        if not self.windows or count <= 0:
            return 0
        # Sort by last access time; evict the stalest entries first.
        sorted_keys = sorted(self.windows.keys(), key=lambda k: self.windows[k].last_access_time)
        evicted = 0
        for key in sorted_keys[:count]:
            del self.windows[key]
            evicted += 1
        if evicted:
            logger.warning(f"LRU 淘汰了 {evicted} 个滑动窗口(达到容量上限)")
        return evicted

    async def _maybe_cleanup(self):
        """Run expiry/LRU cleanup when due (caller must hold self._lock)."""
        current_time = time.time()
        # Periodically remove expired windows
        if current_time - self._last_cleanup_time > self._cleanup_interval:
            self._cleanup_expired_windows()
            self._last_cleanup_time = current_time
        # When over capacity, evict via LRU
        if len(self.windows) >= self.max_windows:
            # Evict 10% of the windows
            evict_count = max(1, self.max_windows // 10)
            self._evict_lru_windows(evict_count)

    def _get_window(self, key: str) -> SlidingWindow:
        """
        Get or create the sliding window for a key.

        Args:
            key: Limit key

        Returns:
            The sliding-window instance
        """
        if key not in self.windows:
            # Different key prefixes use different configuration.
            if key.startswith("api_key:"):
                window_size = self.config.get("api_key_window_size", self.default_window_size)
                max_requests = self.config.get("api_key_max_requests", self.default_max_requests)
            elif key.startswith("user:"):
                window_size = self.config.get("user_window_size", self.default_window_size)
                max_requests = self.config.get("user_max_requests", self.default_max_requests * 2)
            else:
                window_size = self.default_window_size
                max_requests = self.default_max_requests
            self.windows[key] = SlidingWindow(window_size, max_requests)
        return self.windows[key]

    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """
        Check the rate limit (does not record a request).

        Args:
            key: Limit key

        Returns:
            The rate-limit check result
        """
        async with self._lock:
            # Opportunistically clean up expired windows
            await self._maybe_cleanup()
            window = self._get_window(key)
            amount = kwargs.get("amount", 1)
            # Can the window take this many more requests?
            allowed = window.can_accept(amount)
            remaining = window.get_remaining()
            reset_at = window.get_reset_time()
            retry_after = None
            if not allowed:
                # Wait until the oldest request slides out of the window.
                retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                message=(
                    None
                    if allowed
                    else f"Rate limit exceeded. Please retry after {retry_after} seconds."
                ),
            )

    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """
        Consume quota by recording requests in the window.

        Args:
            key: Limit key
            amount: Number of requests to record

        Returns:
            Whether the requests were recorded
        """
        async with self._lock:
            window = self._get_window(key)
            success = window.add_request(amount)
            if success:
                logger.debug(f"滑动窗口请求记录成功")
            else:
                logger.warning(f"滑动窗口请求被拒绝:超出速率限制")
            return success

    async def reset(self, key: str):
        """
        Reset the sliding window for a key.

        Args:
            key: Limit key
        """
        async with self._lock:
            if key in self.windows:
                window = self.windows[key]
                window.requests.clear()
                logger.info(f"滑动窗口已重置")

    async def get_stats(self, key: str) -> Dict[str, Any]:
        """
        Get statistics for a key.

        Args:
            key: Limit key

        Returns:
            Statistics dictionary
        """
        async with self._lock:
            window = self._get_window(key)
            window._cleanup()  # drop expired requests first
            return {
                "strategy": "sliding_window",
                "key": key,
                "window_size": window.window_size,
                "max_requests": window.max_requests,
                "current_requests": len(window.requests),
                "remaining": window.get_remaining(),
                "reset_at": window.get_reset_time().isoformat(),
            }

    def configure(self, config: Dict[str, Any]):
        """
        Configure the strategy.

        Supported keys:
        - api_key_window_size: Window size for API keys
        - api_key_max_requests: Max requests for API keys
        - user_window_size: Window size for users (seconds)
        - user_max_requests: Max requests for users
        - max_windows: Max cached windows (memory-leak protection)
        - window_expiry: Window expiry in seconds
        - cleanup_interval: Cleanup check interval in seconds
        """
        super().configure(config)
        self.default_window_size = config.get("default_window_size", self.default_window_size)
        self.default_max_requests = config.get("default_max_requests", self.default_max_requests)
        self.max_windows = config.get("max_windows", self.max_windows)
        self.window_expiry = config.get("window_expiry", self.window_expiry)
        self._cleanup_interval = config.get("cleanup_interval", self._cleanup_interval)

    def get_memory_stats(self) -> Dict[str, Any]:
        """
        Get memory-usage statistics.

        Returns:
            Memory-usage statistics
        """
        return {
            "total_windows": len(self.windows),
            "max_windows": self.max_windows,
            "window_expiry": self.window_expiry,
            "cleanup_interval": self._cleanup_interval,
            "last_cleanup_time": self._last_cleanup_time,
            "usage_percent": (
                (len(self.windows) / self.max_windows * 100) if self.max_windows > 0 else 0
            ),
        }

View File

@@ -0,0 +1,431 @@
"""令牌桶速率限制策略,支持 Redis 分布式后端"""
import asyncio
import os
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class TokenBucket:
    """Classic token bucket: capacity-bounded tokens refilled at a fixed rate."""

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Maximum number of tokens the bucket holds.
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity  # a new bucket starts full
        self.last_refill = time.time()

    def _refill(self):
        """Credit tokens for the time elapsed since the last refill."""
        now = time.time()
        earned = (now - self.last_refill) * self.refill_rate
        if earned > 0:
            self.tokens = min(self.capacity, self.tokens + earned)
            self.last_refill = now

    def consume(self, amount: int = 1) -> bool:
        """Take `amount` tokens; return False when not enough are available.

        Args:
            amount: Number of tokens to take.
        """
        self._refill()
        if self.tokens < amount:
            return False
        self.tokens -= amount
        return True

    def get_remaining(self) -> int:
        """Return the current token count, truncated to a whole number."""
        self._refill()
        return int(self.tokens)

    def get_reset_time(self) -> datetime:
        """Return the time at which the bucket will be full again."""
        if self.tokens >= self.capacity:
            return datetime.now()
        deficit = self.capacity - self.tokens
        return datetime.now() + timedelta(seconds=deficit / self.refill_rate)
class TokenBucketStrategy(RateLimitStrategy):
"""
令牌桶算法速率限制策略
特点:
- 允许突发流量
- 平均速率受限
- 适合处理不均匀的流量模式
"""
def __init__(self):
super().__init__("token_bucket")
self.buckets: Dict[str, TokenBucket] = {}
self._lock = asyncio.Lock()
# 默认配置
self.default_capacity = 100 # 默认桶容量
self.default_refill_rate = 10 # 默认每秒补充10个令牌
# 可选的 Redis 后端
self._redis_backend: Optional[RedisTokenBucketBackend] = None
self._redis_checked = False
self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()
def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
"""
获取或创建令牌桶
Args:
key: 限制键
rate_limit: 每分钟请求限制(来自数据库配置),如果提供则使用此值
Returns:
令牌桶实例
"""
if key not in self.buckets:
# 如果提供了rate_limit参数来自数据库优先使用
if rate_limit is not None:
# rate_limit 是每分钟请求数,转换为令牌桶参数
capacity = rate_limit # 桶容量等于每分钟限制
refill_rate = rate_limit / 60.0 # 每秒补充的令牌数
# 否则根据key的不同前缀使用不同的配置
elif key.startswith("api_key:"):
capacity = self.config.get("api_key_capacity", self.default_capacity)
refill_rate = self.config.get("api_key_refill_rate", self.default_refill_rate)
elif key.startswith("user:"):
capacity = self.config.get("user_capacity", self.default_capacity * 2)
refill_rate = self.config.get("user_refill_rate", self.default_refill_rate * 2)
else:
capacity = self.default_capacity
refill_rate = self.default_refill_rate
self.buckets[key] = TokenBucket(capacity, refill_rate)
return self.buckets[key]
def _want_redis_backend(self) -> bool:
return self._backend_mode in {"auto", "redis"}
async def _ensure_backend(self):
if self._redis_checked:
return
self._redis_checked = True
if not self._want_redis_backend():
return
redis_client = get_redis_client_sync()
if redis_client:
self._redis_backend = RedisTokenBucketBackend(redis_client)
logger.info("速率限制改用 Redis 令牌桶后端")
elif self._backend_mode == "redis":
logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
"""
检查速率限制
Args:
key: 限制键
**kwargs: 额外参数,包括 rate_limit (从数据库配置)
Returns:
速率限制检查结果
"""
await self._ensure_backend()
rate_limit = kwargs.get("rate_limit")
amount = kwargs.get("amount", 1)
if self._redis_backend:
return await self._redis_backend.peek(
key=key,
capacity=self._resolve_capacity(key, rate_limit),
refill_rate=self._resolve_refill_rate(key, rate_limit),
amount=amount,
)
async with self._lock:
bucket = self._get_bucket(key, rate_limit)
remaining = bucket.get_remaining()
reset_at = bucket.get_reset_time()
allowed = remaining >= amount
retry_after = None
if not allowed:
tokens_needed = amount - remaining
retry_after = int(tokens_needed / bucket.refill_rate) + 1
return RateLimitResult(
allowed=allowed,
remaining=remaining,
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
"""
消费令牌
Args:
key: 限制键
amount: 消费数量
Returns:
是否成功消费
"""
await self._ensure_backend()
if self._redis_backend:
success, remaining = await self._redis_backend.consume(
key=key,
capacity=self._resolve_capacity(key, kwargs.get("rate_limit")),
refill_rate=self._resolve_refill_rate(key, kwargs.get("rate_limit")),
amount=amount,
)
if success:
logger.debug("Redis 令牌消费成功")
else:
logger.warning("Redis 令牌消费失败")
return success
async with self._lock:
bucket = self._get_bucket(key)
success = bucket.consume(amount)
if success:
logger.debug(f"令牌消费成功")
else:
logger.warning(f"令牌消费失败:超出速率限制")
return success
async def reset(self, key: str):
"""
重置令牌桶
Args:
key: 限制键
"""
await self._ensure_backend()
if self._redis_backend:
await self._redis_backend.reset(key)
return
async with self._lock:
if key in self.buckets:
bucket = self.buckets[key]
bucket.tokens = bucket.capacity
bucket.last_refill = time.time()
logger.info(f"令牌桶已重置")
async def get_stats(self, key: str) -> Dict[str, Any]:
"""
获取统计信息
Args:
key: 限制键
Returns:
统计信息
"""
await self._ensure_backend()
if self._redis_backend:
return await self._redis_backend.get_stats(
key,
capacity=self._resolve_capacity(key),
refill_rate=self._resolve_refill_rate(key),
)
async with self._lock:
bucket = self._get_bucket(key)
return {
"strategy": "token_bucket",
"key": key,
"capacity": bucket.capacity,
"remaining": bucket.get_remaining(),
"refill_rate": bucket.refill_rate,
"reset_at": bucket.get_reset_time().isoformat(),
}
def configure(self, config: Dict[str, Any]):
"""
配置策略
支持的配置项:
- api_key_capacity: API Key的桶容量
- api_key_refill_rate: API Key的令牌补充速率
- user_capacity: 用户的桶容量
- user_refill_rate: 用户的令牌补充速率
"""
super().configure(config)
self.default_capacity = config.get("default_capacity", self.default_capacity)
self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate)
def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
    """Resolve the bucket capacity for *key*; an explicit override wins."""
    if rate_limit is not None:
        return rate_limit
    # Prefix-based dispatch: (key prefix, config key, fallback value).
    for prefix, cfg_name, fallback in (
        ("api_key:", "api_key_capacity", self.default_capacity),
        ("user:", "user_capacity", self.default_capacity * 2),
    ):
        if key.startswith(prefix):
            return self.config.get(cfg_name, fallback)
    return self.default_capacity
def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
    """Resolve the refill rate (tokens/second) for *key*."""
    if rate_limit is not None:
        # Explicit limits are expressed per minute; convert to per second.
        return rate_limit / 60.0
    # Prefix-based dispatch: (key prefix, config key, fallback value).
    for prefix, cfg_name, fallback in (
        ("api_key:", "api_key_refill_rate", self.default_refill_rate),
        ("user:", "user_refill_rate", self.default_refill_rate * 2),
    ):
        if key.startswith(prefix):
            return self.config.get(cfg_name, fallback)
    return self.default_refill_rate
class RedisTokenBucketBackend:
    """Token-bucket state kept in Redis so the limit is shared across instances.

    Each bucket is a Redis hash with fields ``tokens`` and ``timestamp``
    (last refill time, unix seconds). Consumption runs server-side as a Lua
    script so that refill + consume is atomic across processes.
    """

    # Lua: refill the bucket from elapsed time, then try to take `amount`
    # tokens. Returns {allowed (0/1), remaining tokens, retry_after seconds}.
    # The key expires once an idle bucket would be full again, so unused
    # buckets clean themselves up.
    _SCRIPT = """
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local refill_rate = tonumber(ARGV[3])
    local amount = tonumber(ARGV[4])
    local data = redis.call('HMGET', key, 'tokens', 'timestamp')
    local tokens = tonumber(data[1])
    local last_refill = tonumber(data[2])
    if tokens == nil then
        tokens = capacity
        last_refill = now
    end
    local delta = math.max(0, now - last_refill)
    local refill = delta * refill_rate
    tokens = math.min(capacity, tokens + refill)
    local allowed = 0
    local retry_after = 0
    if tokens >= amount then
        tokens = tokens - amount
        allowed = 1
    else
        retry_after = math.ceil((amount - tokens) / refill_rate)
    end
    redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
    local ttl = math.max(1, math.ceil(capacity / refill_rate))
    redis.call('EXPIRE', key, ttl)
    return {allowed, tokens, retry_after}
    """

    def __init__(self, redis_client):
        """Wrap an async Redis client.

        Args:
            redis_client: Async Redis client supporting ``register_script``
                (e.g. ``redis.asyncio.Redis``).
        """
        self.redis = redis_client
        self._consume_script = self.redis.register_script(self._SCRIPT)

    def _redis_key(self, key: str) -> str:
        """Namespace bucket keys to avoid clashing with other Redis data."""
        return f"rate_limit:bucket:{key}"

    async def peek(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> RateLimitResult:
        """Report whether *amount* tokens could be consumed, without consuming.

        This is a read-only estimate and may race with concurrent
        ``consume`` calls on other instances.
        """
        bucket_key = self._redis_key(key)
        data = await self.redis.hmget(bucket_key, "tokens", "timestamp")
        tokens = data[0]
        last_refill = data[1]
        if tokens is None or last_refill is None:
            # Bucket has never been touched: it is full, so the "reset"
            # moment is now (mirrors the tokens >= capacity case below).
            remaining = capacity
            reset_at = datetime.now()
        else:
            tokens_value = float(tokens)
            last_refill_value = float(last_refill)
            delta = max(0.0, time.time() - last_refill_value)
            tokens_value = min(capacity, tokens_value + delta * refill_rate)
            remaining = int(tokens_value)
            reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
            reset_at = datetime.now() + timedelta(seconds=reset_after)
        allowed = remaining >= amount
        retry_after = None
        if not allowed:
            needed = max(0, amount - remaining)
            retry_after = int(needed / refill_rate) + 1
        return RateLimitResult(
            allowed=allowed,
            remaining=int(remaining),
            reset_at=reset_at,
            retry_after=retry_after,
            message=(
                None
                if allowed
                else f"Rate limit exceeded. Please retry after {retry_after} seconds."
            ),
        )

    async def consume(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> Tuple[bool, int]:
        """Atomically consume *amount* tokens.

        Returns:
            (allowed, remaining_tokens) as computed by the Lua script.
        """
        result = await self._consume_script(
            keys=[self._redis_key(key)],
            args=[time.time(), capacity, refill_rate, amount],
        )
        allowed = bool(result[0])
        remaining = int(float(result[1]))
        return allowed, remaining

    async def reset(self, key: str):
        """Drop the bucket state; the next access sees a full bucket."""
        await self.redis.delete(self._redis_key(key))

    async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
        """Return a raw snapshot of the bucket state for monitoring."""
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        tokens = data[0]
        timestamp = data[1]
        return {
            "strategy": "token_bucket",
            "key": key,
            "capacity": capacity,
            # Missing hash fields mean the bucket was never used → full.
            "remaining": float(tokens) if tokens else capacity,
            "refill_rate": refill_rate,
            "last_refill": float(timestamp) if timestamp else time.time(),
            "backend": "redis",
        }

View File

@@ -0,0 +1,9 @@
"""
Token计数插件
"""
from .base import TokenCounterPlugin, TokenUsage
from .claude_counter import ClaudeTokenCounterPlugin
from .tiktoken_counter import TiktokenCounterPlugin
__all__ = ["TokenCounterPlugin", "TokenUsage", "TiktokenCounterPlugin", "ClaudeTokenCounterPlugin"]

170
src/plugins/token/base.py Normal file
View File

@@ -0,0 +1,170 @@
"""
Token计数插件基类
定义Token计数的接口
"""
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from src.plugins.common import BasePlugin
@dataclass
class TokenUsage:
    """Aggregated token counts for a single request/response pair."""

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0  # Claude prompt-cache reads
    cache_write_tokens: int = 0  # Claude prompt-cache writes
    reasoning_tokens: int = 0  # OpenAI o1 reasoning tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Element-wise sum of two usage records."""
        names = (
            "input_tokens",
            "output_tokens",
            "total_tokens",
            "cache_read_tokens",
            "cache_write_tokens",
            "reasoning_tokens",
        )
        return TokenUsage(**{n: getattr(self, n) + getattr(other, n) for n in names})

    def to_dict(self) -> Dict[str, int]:
        """Plain-dict view of all counters (insertion order = field order)."""
        names = (
            "input_tokens",
            "output_tokens",
            "total_tokens",
            "cache_read_tokens",
            "cache_write_tokens",
            "reasoning_tokens",
        )
        return {n: getattr(self, n) for n in names}
class TokenCounterPlugin(BasePlugin):
    """
    Base class for token-counting plugins.

    Concrete subclasses implement model support checks and the actual
    counting; this base provides request/response accounting and cost
    estimation shared by all counters.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        """Initialize plugin metadata and read counter settings.

        Args:
            name: Plugin name.
            config: Optional configuration; may contain ``supported_models``,
                ``default_model`` and ``pricing``.
        """
        # Let the base plugin set up metadata (name/config/description/version).
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )
        self.supported_models = self.config.get("supported_models", [])
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return True if this counter can handle *model*."""
        pass

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count the tokens of a plain text string."""
        pass

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Count the tokens of a chat message list."""
        pass

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the prompt tokens of a full request payload."""
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a provider response.

        Understands both OpenAI-style (``prompt_tokens``/``completion_tokens``)
        and Claude-style (``input_tokens``/``output_tokens``) usage blocks.
        Returns an all-zero TokenUsage when neither is present.
        """
        usage = response.get("usage", {})
        # OpenAI format
        if "prompt_tokens" in usage:
            # completion_tokens_details may be missing or explicitly null
            # in OpenAI responses; guard before .get().
            details = usage.get("completion_tokens_details") or {}
            return TokenUsage(
                input_tokens=usage.get("prompt_tokens", 0),
                output_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                reasoning_tokens=details.get("reasoning_tokens", 0),
            )
        # Claude format
        elif "input_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                cache_read_tokens=usage.get("cache_read_input_tokens", 0),
                cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
            )
        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """Estimate the USD cost of *usage* for *model*.

        Pricing comes from ``config["pricing"]`` (USD per 1M tokens),
        matched by exact model name first, then by prefix. Returns an
        ``{"error": ...}`` dict when no pricing is known.
        """
        # Price table: USD per 1M tokens.
        pricing = self.config.get("pricing", {})
        # Exact model match first.
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Fall back to prefix matching.
            for model_prefix, price_info in pricing.items():
                if model.startswith(model_prefix):
                    model_pricing = price_info
                    break
        if not model_pricing:
            return {"error": "No pricing information available"}
        # Base input/output cost.
        input_cost = (usage.input_tokens / 1_000_000) * model_pricing.get("input", 0)
        output_cost = (usage.output_tokens / 1_000_000) * model_pricing.get("output", 0)
        # Cache cost (Claude-specific).
        cache_read_cost = (usage.cache_read_tokens / 1_000_000) * model_pricing.get("cache_read", 0)
        cache_write_cost = (usage.cache_write_tokens / 1_000_000) * model_pricing.get(
            "cache_write", 0
        )
        # Reasoning cost (OpenAI o1-specific).
        reasoning_cost = (usage.reasoning_tokens / 1_000_000) * model_pricing.get("reasoning", 0)
        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost
        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return metadata (context size, pricing, ...) for *model*."""
        pass

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics for monitoring."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }

View File

@@ -0,0 +1,273 @@
"""
Claude Token计数插件
专门为Claude模型设计的Token计数器
"""
import json
import re
from typing import Any, Dict, List, Optional
from .base import TokenCounterPlugin, TokenUsage
class ClaudeTokenCounterPlugin(TokenCounterPlugin):
    """
    Claude-specific token counter plugin.

    Uses a simplified character-based estimation instead of a real
    tokenizer, so all counts are approximations.
    """

    # Claude model metadata: context window, max output tokens and the
    # average characters-per-token ratio used for estimation.
    CLAUDE_MODELS = {
        "claude-3-5-sonnet-20241022": {
            "max_tokens": 200000,
            "max_output": 8192,
            "chars_per_token": 3.5,  # average characters per token
        },
        "claude-3-5-haiku-20241022": {
            "max_tokens": 200000,
            "max_output": 8192,
            "chars_per_token": 3.5,
        },
        "claude-3-opus-20240229": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        "claude-3-sonnet-20240229": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        "claude-3-haiku-20240307": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        # Legacy models
        "claude-2.1": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
        "claude-2.0": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
        "claude-instant-1.2": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
    }

    def __init__(self, name: str = "claude", config: Optional[Dict[str, Any]] = None):
        """Initialize the counter and install a default USD price table."""
        super().__init__(name, config)
        # Price table: USD per 1M tokens.
        default_pricing = {
            "claude-3-5-sonnet": {
                "input": 3,
                "output": 15,
                "cache_write": 3.75,  # prompt-cache write
                "cache_read": 0.30,  # prompt-cache read
            },
            "claude-3-5-haiku": {
                "input": 0.8,
                "output": 4,
                "cache_write": 1,
                "cache_read": 0.08,
            },
            "claude-3-opus": {
                "input": 15,
                "output": 75,
                "cache_write": 18.75,
                "cache_read": 1.50,
            },
            "claude-3-sonnet": {
                "input": 3,
                "output": 15,
                "cache_write": 3.75,
                "cache_read": 0.30,
            },
            "claude-3-haiku": {
                "input": 0.25,
                "output": 1.25,
                "cache_write": 0.30,
                "cache_read": 0.03,
            },
            "claude-2.1": {
                "input": 8,
                "output": 24,
            },
            "claude-2.0": {
                "input": 8,
                "output": 24,
            },
            "claude-instant": {
                "input": 0.8,
                "output": 2.4,
            },
        }
        self.config["pricing"] = (
            config.get("pricing", default_pricing) if config else default_pricing
        )

    def supports_model(self, model: str) -> bool:
        """Accept any model whose name contains "claude"."""
        # Supports all Claude models
        return "claude" in model.lower()

    def _estimate_tokens_from_text(self, text: str, model: str) -> int:
        """Estimate the token count of *text* for *model*.

        CJK-heavy text is counted per character; other text averages a
        word-based and a character-based estimate.
        """
        # Look up model metadata by base name (e.g. "claude-3-5-sonnet").
        model_info = None
        for model_name, info in self.CLAUDE_MODELS.items():
            if model.startswith(model_name.split("-20")[0]):  # match the base name
                model_info = info
                break
        if not model_info:
            # Fallback ratio for unknown models.
            model_info = {"chars_per_token": 3.5}
        # Base estimate.
        chars_per_token = model_info["chars_per_token"]
        # Account for language characteristics:
        # detect Chinese/Japanese/Korean characters, which tokenize denser.
        cjk_pattern = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")
        cjk_count = len(cjk_pattern.findall(text))
        if cjk_count > len(text) * 0.3:  # more than 30% CJK characters
            # CJK characters are usually 1-2 tokens each.
            return int(len(text) / 1.5)
        else:
            # English and other languages (including spaces/punctuation):
            word_count = len(text.split())
            # roughly 1.3 tokens per word
            token_by_words = int(word_count * 1.3)
            # roughly chars_per_token characters per token
            token_by_chars = int(len(text) / chars_per_token)
            # Average the two estimates.
            return (token_by_words + token_by_chars) // 2

    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Estimate the token count of a plain text string."""
        if not self.enabled:
            return 0
        model = model or self.default_model or "claude-3-5-sonnet-20241022"
        return self._estimate_tokens_from_text(text, model)

    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Estimate the token count of a message list."""
        if not self.enabled:
            return 0
        model = model or self.default_model or "claude-3-5-sonnet-20241022"
        total_tokens = 0
        for message in messages:
            # Role overhead: roughly 3 tokens per message.
            total_tokens += 3
            # Content tokens.
            content = message.get("content")
            if content:
                if isinstance(content, str):
                    total_tokens += self._estimate_tokens_from_text(content, model)
                elif isinstance(content, list):
                    # Multimodal content blocks.
                    for item in content:
                        if item.get("type") == "text":
                            text = item.get("text", "")
                            total_tokens += self._estimate_tokens_from_text(text, model)
                        elif item.get("type") == "image":
                            # Claude image cost: ~1,600 base tokens plus
                            # ~280 per 256x256 tile; use a flat estimate.
                            total_tokens += 2000  # average estimate
                        elif item.get("type") == "tool_use":
                            # Tool invocation: name plus serialized input.
                            tool_name = item.get("name", "")
                            tool_input = item.get("input", {})
                            total_tokens += self._estimate_tokens_from_text(tool_name, model)
                            total_tokens += self._estimate_tokens_from_text(
                                json.dumps(tool_input), model
                            )
                        elif item.get("type") == "tool_result":
                            # Tool result (only plain-string content counted).
                            tool_content = item.get("content", "")
                            if isinstance(tool_content, str):
                                total_tokens += self._estimate_tokens_from_text(tool_content, model)
        # Extra overhead when a system prompt rides in the message list.
        if messages and messages[0].get("role") == "system":
            # System prompts typically carry additional formatting cost.
            total_tokens += 10
        return total_tokens

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Estimate the prompt tokens of a full Claude request payload."""
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        total = await self.count_messages(messages, model)
        # Claude passes the system prompt as a top-level field.
        # NOTE(review): assumes `system` is a plain string — the Messages API
        # also accepts a list of content blocks; confirm with callers.
        system = request.get("system")
        if system:
            total += self._estimate_tokens_from_text(system, model)
            total += 5  # system prompt overhead
        return total

    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return metadata and pricing for *model*."""
        info = {"model": model, "supported": self.supports_model(model)}
        if self.supports_model(model):
            # Find matching model metadata by base name.
            model_info = None
            for model_name, m_info in self.CLAUDE_MODELS.items():
                if model.startswith(model_name.split("-20")[0]):
                    model_info = m_info
                    info["model_name"] = model_name
                    break
            if model_info:
                info.update(
                    {
                        "max_tokens": model_info["max_tokens"],
                        "max_output": model_info["max_output"],
                        "chars_per_token": model_info["chars_per_token"],
                        "supports_vision": "claude-3" in model,
                        "supports_tools": "claude-3" in model,
                        "supports_cache": "claude-3" in model,
                    }
                )
            # Attach pricing by prefix match.
            pricing = self.config.get("pricing", {})
            for price_key in pricing:
                if model.startswith(price_key):
                    info["pricing"] = pricing[price_key]
                    break
        return info

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics."""
        stats = await super().get_stats()
        stats.update(
            {
                "estimation_method": "character_based",
                "supported_models_count": len(self.CLAUDE_MODELS),
            }
        )
        return stats

View File

@@ -0,0 +1,269 @@
"""
Tiktoken Token计数插件
支持OpenAI和其他使用tiktoken的模型
"""
import json
from typing import Any, Dict, List, Optional
from src.core.logger import logger
from .base import TokenCounterPlugin, TokenUsage
# 尝试导入tiktoken
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
tiktoken = None
class TiktokenCounterPlugin(TokenCounterPlugin):
    """
    Token counter backed by the tiktoken library.

    Supports OpenAI models and compatible ones; disables itself when
    tiktoken is not installed.
    """

    # model -> tiktoken encoding name
    MODEL_ENCODINGS = {
        # GPT-4 family
        "gpt-4": "cl100k_base",
        "gpt-4-32k": "cl100k_base",
        "gpt-4-turbo": "cl100k_base",
        "gpt-4-turbo-preview": "cl100k_base",
        "gpt-4o": "o200k_base",
        "gpt-4o-mini": "o200k_base",
        # GPT-3.5 family
        "gpt-3.5-turbo": "cl100k_base",
        "gpt-3.5-turbo-16k": "cl100k_base",
        # Legacy models
        "text-davinci-003": "p50k_base",
        "text-davinci-002": "p50k_base",
        "code-davinci-002": "p50k_base",
        # Embeddings
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }

    # Extra tokens the chat format charges per message.
    MESSAGE_OVERHEAD = {
        "gpt-3.5-turbo": 4,  # per message
        "gpt-4": 3,
        "gpt-4-turbo": 3,
        "gpt-4o": 3,
        "gpt-4o-mini": 3,
    }

    def __init__(self, name: str = "tiktoken", config: Optional[Dict[str, Any]] = None):
        """Initialize the counter; disables the plugin if tiktoken is missing."""
        super().__init__(name, config)
        # Encoder cache. Initialized before the availability check so that
        # get_stats() never hits a missing attribute on a disabled plugin.
        self._encoders = {}
        if not TIKTOKEN_AVAILABLE:
            self.enabled = False
            logger.warning("tiktoken not installed, plugin disabled")
            return
        # Price table: USD per 1M tokens.
        default_pricing = {
            "gpt-4o": {"input": 2.5, "output": 10},
            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
            "gpt-4-turbo": {"input": 10, "output": 30},
            "gpt-4": {"input": 30, "output": 60},
            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
            "o1-preview": {"input": 15, "output": 60, "reasoning": 60},
            "o1-mini": {"input": 3, "output": 12, "reasoning": 12},
        }
        self.config["pricing"] = (
            config.get("pricing", default_pricing) if config else default_pricing
        )

    def _get_encoder(self, model: str) -> Any:
        """Return (and cache) the tiktoken encoder for *model*."""
        if model in self._encoders:
            return self._encoders[model]
        # Resolve the encoding name: exact match first, then prefix match.
        encoding_name = None
        if model in self.MODEL_ENCODINGS:
            encoding_name = self.MODEL_ENCODINGS[model]
        else:
            for model_prefix, enc_name in self.MODEL_ENCODINGS.items():
                if model.startswith(model_prefix):
                    encoding_name = enc_name
                    break
        # Unknown model: let tiktoken try to resolve it by name.
        if not encoding_name:
            try:
                encoder = tiktoken.encoding_for_model(model)
                self._encoders[model] = encoder
                return encoder
            except Exception:
                # Unknown model name — fall back to cl100k_base.
                encoding_name = "cl100k_base"
        encoder = tiktoken.get_encoding(encoding_name)
        self._encoders[model] = encoder
        return encoder

    def supports_model(self, model: str) -> bool:
        """Accept OpenAI models and known compatible prefixes."""
        openai_models = ["gpt-4", "gpt-3.5", "text-davinci", "text-embedding", "code-davinci", "o1"]
        return any(model.startswith(prefix) for prefix in openai_models)

    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count the tokens of *text* with the model's encoder."""
        if not self.enabled:
            return 0
        model = model or self.default_model or "gpt-3.5-turbo"
        encoder = self._get_encoder(model)
        try:
            tokens = encoder.encode(text)
            return len(tokens)
        except Exception as e:
            logger.warning(f"Error counting tokens: {e}")
            # Rough fallback: ~0.75 tokens per character.
            return int(len(text) * 0.75)

    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Count the tokens of a chat message list, including format overhead."""
        if not self.enabled:
            return 0
        model = model or self.default_model or "gpt-3.5-turbo"
        encoder = self._get_encoder(model)
        # Per-message overhead of the chat format.
        msg_overhead = self.MESSAGE_OVERHEAD.get(model, 3)
        total_tokens = 0
        for message in messages:
            # Base cost of each message.
            total_tokens += msg_overhead
            # Role tokens.
            role = message.get("role", "")
            if role:
                total_tokens += len(encoder.encode(role))
            # Content tokens.
            content = message.get("content")
            if content:
                if isinstance(content, str):
                    total_tokens += len(encoder.encode(content))
                elif isinstance(content, list):
                    # Multimodal content blocks.
                    for item in content:
                        if item.get("type") == "text":
                            text = item.get("text", "")
                            total_tokens += len(encoder.encode(text))
                        elif item.get("type") == "image_url":
                            # Image accounting is complex; simplified here:
                            # low detail 85 tokens, high detail 170 tokens.
                            detail = item.get("image_url", {}).get("detail", "auto")
                            total_tokens += 170 if detail == "high" else 85
            # Name tokens.
            name = message.get("name")
            if name:
                total_tokens += len(encoder.encode(name)) - 1  # name costs one token less
            # Tool calls.
            tool_calls = message.get("tool_calls")
            if tool_calls:
                for tool_call in tool_calls:
                    # Tool call id.
                    if "id" in tool_call:
                        total_tokens += len(encoder.encode(tool_call["id"]))
                    # Function name and serialized arguments.
                    function = tool_call.get("function", {})
                    if "name" in function:
                        total_tokens += len(encoder.encode(function["name"]))
                    if "arguments" in function:
                        total_tokens += len(encoder.encode(function["arguments"]))
        # Fixed reply-priming overhead.
        total_tokens += 3
        return total_tokens

    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return encoding, context-size and pricing info for *model*."""
        info = {"model": model, "supported": self.supports_model(model)}
        if self.supports_model(model):
            # Resolve the encoder and its encoding name.
            encoder = self._get_encoder(model)
            encoding_name = None
            for m, enc in self.MODEL_ENCODINGS.items():
                if model.startswith(m):
                    encoding_name = enc
                    break
            info.update(
                {
                    "encoding": encoding_name or "unknown",
                    "vocab_size": encoder.n_vocab if hasattr(encoder, "n_vocab") else None,
                    "max_tokens": self._get_max_tokens(model),
                    "message_overhead": self.MESSAGE_OVERHEAD.get(model, 3),
                }
            )
            # Attach pricing on exact match.
            pricing = self.config.get("pricing", {})
            if model in pricing:
                info["pricing"] = pricing[model]
        return info

    def _get_max_tokens(self, model: str) -> int:
        """Best-known context window for *model* (4096 when unknown)."""
        max_tokens_map = {
            "gpt-4": 8192,
            "gpt-4-32k": 32768,
            "gpt-4-turbo": 128000,
            "gpt-4o": 128000,
            "gpt-4o-mini": 128000,
            "gpt-3.5-turbo": 4096,
            "gpt-3.5-turbo-16k": 16384,
            "o1-preview": 128000,
            "o1-mini": 128000,
        }
        # Exact match first.
        if model in max_tokens_map:
            return max_tokens_map[model]
        # Prefix match.
        for model_prefix, max_tokens in max_tokens_map.items():
            if model.startswith(model_prefix):
                return max_tokens
        # Default.
        return 4096

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics."""
        stats = await super().get_stats()
        stats.update(
            {"encoders_cached": len(self._encoders), "tiktoken_available": TIKTOKEN_AVAILABLE}
        )
        return stats