mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
Initial commit
This commit is contained in:
3
src/plugins/__init__.py
Normal file
3
src/plugins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
插件系统基础模块
|
||||
"""
|
||||
8
src/plugins/auth/__init__.py
Normal file
8
src/plugins/auth/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
认证插件模块
|
||||
"""
|
||||
|
||||
from .api_key import ApiKeyAuthPlugin
|
||||
from .base import AuthContext, AuthPlugin
|
||||
|
||||
__all__ = ["AuthPlugin", "AuthContext", "ApiKeyAuthPlugin"]
|
||||
96
src/plugins/auth/api_key.py
Normal file
96
src/plugins/auth/api_key.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
API Key认证插件
|
||||
支持从header中提取API Key进行认证
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
from .base import AuthContext, AuthPlugin
|
||||
|
||||
|
||||
|
||||
class ApiKeyAuthPlugin(AuthPlugin):
    """
    API Key authentication plugin.

    Extracts an API key from the ``x-api-key`` header or from an
    ``Authorization: Bearer <key>`` header and authenticates it via the
    auth service, attaching quota information to the resulting context.
    """

    def __init__(self):
        super().__init__(name="api_key", priority=10)

    def get_credentials(self, request: Request) -> Optional[str]:
        """
        Extract the API key from the request headers.

        Two formats are supported:
        1. ``x-api-key: <key>``
        2. ``Authorization: Bearer <key>``

        Returns:
            The raw key string, or None if neither header is present.
        """
        # Prefer the dedicated x-api-key header.
        api_key = request.headers.get("x-api-key")
        if api_key:
            return api_key

        # Fall back to the Authorization header.
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # Bug fix: str.replace("Bearer ", "") removed EVERY occurrence
            # of "Bearer " anywhere in the value; strip only the prefix.
            return auth_header[len("Bearer "):]

        return None

    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """
        Authenticate the request using an API key.

        Args:
            request: incoming FastAPI request.
            db: SQLAlchemy session used by the auth/usage services.

        Returns:
            An AuthContext on success, None when no key is present or
            authentication fails.
        """
        # Extract the API key from the headers.
        api_key = self.get_credentials(request)
        if not api_key:
            logger.debug("未找到API Key凭据")
            return None

        # Authenticate the key against the auth service.
        auth_result = AuthService.authenticate_api_key(db, api_key)
        if not auth_result:
            logger.warning("API Key认证失败")
            return None

        user, api_key_obj = auth_result

        # Check the user quota (or the key's own balance for standalone keys).
        quota_ok, message = UsageService.check_user_quota(db, user, api_key=api_key_obj)

        # Build the authentication context for downstream consumers.
        auth_context = AuthContext(
            user_id=user.id,
            user_name=user.username,
            api_key_id=api_key_obj.id,
            api_key_name=api_key_obj.name if hasattr(api_key_obj, "name") else None,
            permissions={
                "can_use_api": quota_ok,
                "is_admin": user.is_admin if hasattr(user, "is_admin") else False,
                "is_standalone_key": api_key_obj.is_standalone,  # standalone-balance key marker
            },
            quota_info={
                "quota_usd": user.quota_usd,
                "used_usd": user.used_usd,
                # None quota means "unlimited", so remaining stays None too.
                "remaining_usd": None if user.quota_usd is None else user.quota_usd - user.used_usd,
                "quota_ok": quota_ok,
                "message": message,
            },
            metadata={
                "auth_method": "api_key",
                "client_ip": request.client.host if request.client else "unknown",
                "is_standalone": api_key_obj.is_standalone,  # duplicated here for metadata readers
            },
        )

        logger.info("API Key认证成功")

        return auth_context
|
||||
120
src/plugins/auth/base.py
Normal file
120
src/plugins/auth/base.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
认证插件基类
|
||||
定义认证插件的接口和认证上下文
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..common import BasePlugin, HealthStatus, PluginMetadata
|
||||
|
||||
|
||||
@dataclass
class AuthContext:
    """
    Authentication context produced by a successful authentication.

    Carries the identity of the authenticated user, the API key used (if
    any), the granted permissions, quota information, and free-form
    metadata about the authentication event.
    """

    user_id: int
    user_name: str
    api_key_id: Optional[int] = None
    api_key_name: Optional[str] = None
    permissions: Optional[Dict[str, bool]] = None
    quota_info: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Normalize None defaults to fresh empty dicts so callers can read
        # these fields without None checks (mutable defaults are unsafe as
        # dataclass field defaults, hence the None sentinel).
        if self.permissions is None:
            self.permissions = {}
        if self.quota_info is None:
            # Bug fix: quota_info was previously left as None here while
            # permissions and metadata were normalized.
            self.quota_info = {}
        if self.metadata is None:
            self.metadata = {}
|
||||
|
||||
|
||||
class AuthPlugin(BasePlugin):
    """
    Base class for authentication plugins.

    Concrete plugins implement ``get_credentials`` (credential extraction)
    and ``authenticate`` (credential verification).
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """
        Initialize the authentication plugin.

        Args:
            name: plugin name.
            priority: priority (higher number = higher priority).
            version: plugin version.
            author: plugin author.
            description: plugin description.
            api_version: API version.
            dependencies: names of other plugins this one depends on.
            provides: services this plugin provides.
            config: configuration dictionary.
        """
        # Pure pass-through to BasePlugin; auth plugins add no extra state.
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )

    @abstractmethod
    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """
        Perform authentication.

        Args:
            request: FastAPI request object.
            db: database session.

        Returns:
            An AuthContext on success, None on failure.
        """

    @abstractmethod
    def get_credentials(self, request: Request) -> Optional[str]:
        """
        Extract the authentication credential from the request.

        Args:
            request: FastAPI request object.

        Returns:
            The credential string, or None when absent.
        """

    def is_applicable(self, request: Request) -> bool:
        """
        Decide whether this plugin applies to the given request.

        Args:
            request: FastAPI request object.

        Returns:
            True when the plugin should handle the request.
        """
        # By default a plugin applies whenever it can find its credential.
        credentials = self.get_credentials(request)
        return credentials is not None
|
||||
103
src/plugins/auth/jwt.py
Normal file
103
src/plugins/auth/jwt.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
JWT认证插件
|
||||
支持JWT Bearer token认证
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
from src.services.auth.service import AuthService
|
||||
|
||||
from .base import AuthContext, AuthPlugin
|
||||
|
||||
|
||||
|
||||
class JwtAuthPlugin(AuthPlugin):
    """
    JWT authentication plugin.

    Extracts a JWT bearer token from the ``Authorization`` header,
    verifies it via the auth service, and loads the user from the DB.
    """

    def __init__(self):
        # Higher priority than the API-key plugin so JWT wins when both apply.
        super().__init__(name="jwt", priority=20)

    def get_credentials(self, request: Request) -> Optional[str]:
        """
        Extract the JWT token from the Authorization header.

        Supported format: ``Authorization: Bearer <token>``.

        Returns:
            The raw token, or None when the header is absent/malformed.
        """
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # Bug fix: str.replace("Bearer ", "") removed EVERY occurrence
            # of "Bearer " in the value; strip only the leading prefix.
            return auth_header[len("Bearer "):]
        return None

    async def authenticate(self, request: Request, db: Session) -> Optional[AuthContext]:
        """
        Authenticate the request with a JWT token.

        Args:
            request: incoming FastAPI request.
            db: SQLAlchemy session used to load the user.

        Returns:
            An AuthContext on success, None on any failure (missing token,
            invalid token, unknown or disabled user).
        """
        # Extract the token from the request headers.
        token = self.get_credentials(request)
        if not token:
            logger.debug("未找到JWT token")
            return None

        # NOTE(review): this logs a prefix of the raw token; confirm that
        # leaking 20 token characters into logs is acceptable.
        logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")

        try:
            # Verify the token signature/expiry; raises on failure.
            payload = AuthService.verify_token(token)
            logger.debug(f"JWT token验证成功, payload: {payload}")

            # The payload must carry the user id.
            user_id = payload.get("user_id")
            if not user_id:
                logger.warning("JWT token中缺少用户ID")
                return None

            logger.debug(f"从JWT提取user_id: {user_id}, 类型: {type(user_id)}")

            # Load the user from the database.
            user = db.query(User).filter(User.id == user_id).first()
            if not user:
                logger.warning(f"JWT认证失败 - 用户不存在: {user_id}")
                return None

            logger.debug(f"找到用户: {user.email}, is_active: {user.is_active}")

            if not user.is_active:
                logger.warning(f"JWT认证失败 - 用户已禁用: {user.email}")
                return None

            # Build the authentication context.
            auth_context = AuthContext(
                user_id=user.id,
                user_name=user.username,
                permissions={"can_use_api": True, "is_admin": user.role.value == "admin"},
                quota_info={
                    "quota_usd": user.quota_usd,
                    "used_usd": user.used_usd,
                    # None quota means "unlimited".
                    "remaining_usd": (
                        None if user.quota_usd is None else user.quota_usd - user.used_usd
                    ),
                    "quota_ok": True,  # JWT用户通常已经通过前端验证
                },
                metadata={
                    "auth_method": "jwt",
                    "client_ip": request.client.host if request.client else "unknown",
                    "token_exp": payload.get("exp"),
                },
            )

            logger.info("JWT认证成功")

            return auth_context

        except Exception as e:
            # Any verification/lookup error is treated as "not authenticated"
            # so other auth plugins can still run.
            logger.warning(f"JWT认证失败: {str(e)}")
            return None
|
||||
5
src/plugins/cache/__init__.py
vendored
Normal file
5
src/plugins/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
"""缓存插件包"""
|
||||
|
||||
from .base import CachePlugin
|
||||
|
||||
__all__ = ["CachePlugin"]
|
||||
218
src/plugins/cache/base.py
vendored
Normal file
218
src/plugins/cache/base.py
vendored
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
缓存插件基类
|
||||
定义缓存插件的接口
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..common import BasePlugin, HealthStatus, PluginMetadata
|
||||
|
||||
|
||||
class CachePlugin(BasePlugin):
    """
    Base class for cache plugins.

    Concrete caches implement the async get/set/delete/exists/clear and
    batch operations; key generation and JSON (de)serialization helpers
    are provided here.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """
        Initialize the cache plugin.

        Args:
            name: plugin name.
            priority: priority.
            version: plugin version.
            author: plugin author.
            description: plugin description.
            api_version: API version.
            dependencies: dependency list.
            provides: provided-services list.
            config: configuration dictionary.
        """
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )
        # Tunables, overridable via config / configure().
        self.default_ttl = self.config.get("default_ttl", 3600)  # default: 1 hour
        self.max_size = self.config.get("max_size", 1000)  # max number of entries

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """
        Fetch a cached value.

        Args:
            key: cache key.

        Returns:
            The value, or None if missing.
        """

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Store a value in the cache.

        Args:
            key: cache key.
            value: value to store.
            ttl: expiry in seconds; None uses the default.

        Returns:
            True on success.
        """

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """
        Remove a cache entry.

        Args:
            key: cache key.

        Returns:
            True if the entry was removed.
        """

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """
        Check whether a cache entry exists.

        Args:
            key: cache key.

        Returns:
            True if present.
        """

    @abstractmethod
    async def clear(self) -> bool:
        """
        Drop every cache entry.

        Returns:
            True on success.
        """

    @abstractmethod
    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """
        Fetch several cached values at once.

        Args:
            keys: list of cache keys.

        Returns:
            Mapping of found keys to their values.
        """

    @abstractmethod
    async def set_many(self, items: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """
        Store several values at once.

        Args:
            items: mapping of keys to values.
            ttl: expiry in seconds.

        Returns:
            True on success.
        """

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """
        Report cache statistics.

        Returns:
            Statistics dictionary.
        """

    def generate_key(self, *args, **kwargs) -> str:
        """
        Build a stable cache key from arbitrary arguments.

        Args:
            *args: positional arguments.
            **kwargs: keyword arguments (sorted for determinism).

        Returns:
            A key string, prefixed with the plugin name.
        """
        # Deterministic textual form of the arguments.
        parts = [str(a) for a in args]
        parts += [f"{k}:{v}" for k, v in sorted(kwargs.items())]
        joined = "|".join(parts)

        # Overly long keys are hashed to keep the key length bounded.
        if len(joined) > 250:
            digest = hashlib.md5(joined.encode()).hexdigest()
            return f"{self.name}:{digest}"

        return f"{self.name}:{joined}"

    def serialize(self, value: Any) -> str:
        """
        Serialize a value to a JSON string.

        Args:
            value: value to serialize.

        Returns:
            The JSON string (non-JSON types fall back to str()).
        """
        return json.dumps(value, default=str)

    def deserialize(self, value: str) -> Any:
        """
        Deserialize a JSON string back to a value.

        Args:
            value: serialized string.

        Returns:
            The decoded value.
        """
        return json.loads(value)

    def configure(self, config: Dict[str, Any]):
        """
        Apply configuration, refreshing cache tunables.

        Args:
            config: configuration dictionary.
        """
        super().configure(config)
        # Keep current values when the keys are absent.
        self.default_ttl = config.get("default_ttl", self.default_ttl)
        self.max_size = config.get("max_size", self.max_size)
|
||||
195
src/plugins/cache/memory.py
vendored
Normal file
195
src/plugins/cache/memory.py
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
内存缓存插件
|
||||
基于Python字典的简单内存缓存实现
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import CachePlugin
|
||||
|
||||
|
||||
class MemoryCachePlugin(CachePlugin):
    """
    In-memory cache plugin.

    Implements an LRU cache on top of OrderedDict with TTL-based expiry
    (lazy on access plus a periodic background sweep).
    """

    def __init__(self, name: str = "memory", config: Dict[str, Any] = None):
        # Bug fix: the original called super().__init__(name, config), which
        # passed the config dict positionally into CachePlugin's ``priority``
        # parameter (and dropped the config entirely).
        super().__init__(name=name, config=config)
        self._cache: OrderedDict = OrderedDict()
        self._expiry: Dict[str, float] = {}  # key -> absolute expiry timestamp
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._cleanup_task = None
        self._cleanup_interval = 60  # seconds between background sweeps

        if config is not None:
            self._cleanup_interval = config.get("cleanup_interval", 60)

        # Start the background sweep if an event loop is already running.
        try:
            self._start_cleanup_task()
        except RuntimeError:
            # No running event loop (e.g. constructed synchronously at import
            # time). Expired entries are still purged lazily on access.
            pass

    def _start_cleanup_task(self):
        """Start the background task that periodically purges expired entries."""

        async def cleanup_loop():
            while self.enabled:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()

        # get_running_loop raises RuntimeError outside an async context
        # (get_event_loop is deprecated for this use and could implicitly
        # create a loop that never runs).
        loop = asyncio.get_running_loop()
        self._cleanup_task = loop.create_task(cleanup_loop())

    async def _cleanup_expired(self):
        """Remove every entry whose TTL has elapsed."""
        now = time.time()
        expired_keys = []

        with self._lock:
            # Collect first, then delete, to avoid mutating while iterating.
            for key, expiry in self._expiry.items():
                if expiry < now:
                    expired_keys.append(key)

            for key in expired_keys:
                self._cache.pop(key, None)
                self._expiry.pop(key, None)
                self._evictions += 1

    def _check_size(self):
        """Evict the least-recently-used entry when the cache is full."""
        if len(self._cache) >= self.max_size:
            # OrderedDict iteration order is insertion order; the first
            # entry is the least recently used.
            key = next(iter(self._cache))
            self._cache.pop(key)
            self._expiry.pop(key, None)
            self._evictions += 1

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None if absent/expired."""
        with self._lock:
            # Lazily drop the entry if it has expired.
            if key in self._expiry:
                if self._expiry[key] < time.time():
                    self._cache.pop(key, None)
                    self._expiry.pop(key)
                    self._misses += 1
                    return None

            if key in self._cache:
                # Re-insert at the end to mark as most recently used.
                value = self._cache.pop(key)
                self._cache[key] = value
                self._hits += 1
                # Stored values are always serialized (see set()), so always
                # deserialize on the way out.
                return self.deserialize(value)
            else:
                self._misses += 1
                return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key`` with an optional TTL in seconds."""
        with self._lock:
            # Enforce the size limit before adding a brand-new key.
            if key not in self._cache:
                self._check_size()

            # Bug fix: the original serialized only non-str values, which made
            # get() crash on plain strings (json.loads("hello") raises).
            # Always serializing keeps the stored representation uniform.
            self._cache[key] = self.serialize(value)
            self._cache.move_to_end(key)  # mark as most recently used

            # ttl <= 0 means "never expire".
            if ttl is None:
                ttl = self.default_ttl
            if ttl > 0:
                self._expiry[key] = time.time() + ttl

            return True

    async def delete(self, key: str) -> bool:
        """Remove the entry for ``key``; return True if it existed."""
        with self._lock:
            if key in self._cache:
                self._cache.pop(key)
                self._expiry.pop(key, None)
                return True
            return False

    async def exists(self, key: str) -> bool:
        """Return True when ``key`` is present and not expired."""
        with self._lock:
            # Lazily drop the entry if it has expired.
            if key in self._expiry:
                if self._expiry[key] < time.time():
                    self._cache.pop(key, None)
                    self._expiry.pop(key)
                    return False
            return key in self._cache

    async def clear(self) -> bool:
        """Drop every entry and its expiry record."""
        with self._lock:
            self._cache.clear()
            self._expiry.clear()
            return True

    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Fetch several keys; missing/expired keys are simply omitted."""
        result = {}
        for key in keys:
            value = await self.get(key)
            if value is not None:
                result[key] = value
        return result

    async def set_many(self, items: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Store several key/value pairs; False if any single set failed."""
        success = True
        for key, value in items.items():
            if not await self.set(key, value, ttl):
                success = False
        return success

    async def get_stats(self) -> Dict[str, Any]:
        """Report hit/miss/eviction statistics for this cache."""
        total_requests = self._hits + self._misses
        hit_rate = self._hits / total_requests if total_requests > 0 else 0

        return {
            "type": "memory",
            "size": len(self._cache),
            "max_size": self.max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "evictions": self._evictions,
            "cleanup_interval": self._cleanup_interval,
        }

    async def _do_shutdown(self):
        """Cancel the background cleanup task on shutdown."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

    def __del__(self):
        """Best-effort cleanup if the plugin is garbage-collected."""
        if hasattr(self, "_cleanup_task") and self._cleanup_task:
            self._cleanup_task.cancel()
|
||||
202
src/plugins/common.py
Normal file
202
src/plugins/common.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
插件系统通用定义
|
||||
包含所有插件类型共享的类和接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class HealthStatus(Enum):
    """Health states a plugin can report from its health check."""

    HEALTHY = "healthy"      # fully operational
    DEGRADED = "degraded"    # working, but impaired
    UNHEALTHY = "unhealthy"  # not usable
|
||||
|
||||
|
||||
@dataclass
class PluginMetadata:
    """Static descriptive information attached to every plugin."""

    name: str
    version: str = "1.0.0"
    author: str = "Unknown"
    description: str = ""
    api_version: str = "1.0"
    dependencies: List[str] = None
    provides: List[str] = None

    def __post_init__(self):
        # Replace the None sentinels with fresh lists (a shared mutable
        # default would be unsafe as a dataclass field default).
        self.dependencies = [] if self.dependencies is None else self.dependencies
        self.provides = [] if self.provides is None else self.provides
|
||||
|
||||
|
||||
class BasePlugin(ABC):
    """
    Base class for all plugins.

    Provides the common lifecycle (initialize / shutdown / health check),
    configuration handling, and metadata management. Subclasses override
    the ``_do_*`` hooks rather than the public lifecycle methods.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """
        Initialize the plugin.

        Args:
            name: plugin name.
            priority: priority (higher number = higher priority).
            version: plugin version.
            author: plugin author.
            description: plugin description.
            api_version: API version.
            dependencies: names of other plugins this one depends on.
            provides: services this plugin provides.
            config: configuration dictionary.
        """
        self.name = name
        self.priority = priority
        self.enabled = True
        self.config = config or {}
        self.metadata = PluginMetadata(
            name=name,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies or [],
            provides=provides or [],
        )
        self._initialized = False

    async def initialize(self) -> bool:
        """
        Initialize the plugin (idempotent).

        Returns:
            True on success (or if already initialized), False on failure.
        """
        if self._initialized:
            return True

        try:
            await self._do_initialize()
            self._initialized = True
            return True
        except Exception as e:
            # Initialization failure is logged but not propagated so one
            # broken plugin cannot take the whole system down.
            logger.warning(f"Failed to initialize plugin {self.name}: {e}")
            return False

    async def _do_initialize(self):
        """
        Hook for subclass-specific initialization logic.
        """
        pass

    async def shutdown(self):
        """
        Shut down the plugin and release its resources.
        """
        if not self._initialized:
            return

        try:
            await self._do_shutdown()
        except Exception as e:
            logger.warning(f"Error during plugin {self.name} shutdown: {e}")
        finally:
            # Mark as uninitialized even when cleanup failed, so a later
            # initialize() can retry from a clean state.
            self._initialized = False

    async def _do_shutdown(self):
        """
        Hook for subclass-specific cleanup logic.
        """
        pass

    async def health_check(self) -> HealthStatus:
        """
        Check the plugin's health.

        Returns:
            The plugin health status.
        """
        if not self._initialized or not self.enabled:
            return HealthStatus.UNHEALTHY

        try:
            return await self._do_health_check()
        except Exception:
            # A health probe that raises is itself a sign of ill health.
            return HealthStatus.UNHEALTHY

    async def _do_health_check(self) -> HealthStatus:
        """
        Hook for subclass-specific health checks.

        Default: healthy whenever initialized and enabled.
        """
        return (
            HealthStatus.HEALTHY if (self._initialized and self.enabled) else HealthStatus.UNHEALTHY
        )

    def configure(self, config: Dict[str, Any]):
        """
        Apply configuration on top of the current one.

        Args:
            config: configuration dictionary.
        """
        self.config.update(config)
        # Bug fix: the original defaulted to True, silently re-enabling a
        # disabled plugin whenever "enabled" was absent from the config.
        # Preserve the current value instead (matching how subclasses such
        # as CachePlugin.configure preserve their tunables).
        self.enabled = config.get("enabled", self.enabled)

    def get_metadata(self) -> PluginMetadata:
        """
        Get the plugin metadata.

        Returns:
            The plugin metadata.
        """
        return self.metadata

    @property
    def is_initialized(self) -> bool:
        """Whether the plugin has completed initialization."""
        return self._initialized

    def validate_dependencies(self, available_plugins: Dict[str, List[str]]) -> List[str]:
        """
        Validate that this plugin's dependencies are satisfied.

        Args:
            available_plugins: available plugins as {plugin_type: [plugin_names]}.

        Returns:
            The list of missing dependency names (empty when satisfied).
        """
        return [
            dep
            for dep in self.metadata.dependencies
            if not any(dep in names for names in available_plugins.values())
        ]

    def __repr__(self):
        return f"<{self.__class__.__name__}(name={self.name}, priority={self.priority}, enabled={self.enabled}, version={self.metadata.version})>"
|
||||
13
src/plugins/load_balancer/__init__.py
Normal file
13
src/plugins/load_balancer/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
负载均衡策略插件
|
||||
"""
|
||||
|
||||
from .base import LoadBalancerStrategy, ProviderCandidate, SelectionResult
|
||||
from .sticky_priority import StickyPriorityStrategy
|
||||
|
||||
__all__ = [
|
||||
"LoadBalancerStrategy",
|
||||
"ProviderCandidate",
|
||||
"SelectionResult",
|
||||
"StickyPriorityStrategy",
|
||||
]
|
||||
134
src/plugins/load_balancer/base.py
Normal file
134
src/plugins/load_balancer/base.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
负载均衡策略基类
|
||||
定义负载均衡策略的接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..common import BasePlugin
|
||||
|
||||
|
||||
@dataclass
class ProviderCandidate:
    """One provider eligible for selection by a load-balancing strategy."""

    provider: Any  # the Provider object itself
    priority: int = 0  # higher number = higher priority
    weight: float = 1.0  # relative selection probability within a priority tier
    model: Optional[Any] = None  # optional Model object, when model info matters
    metadata: Optional[Dict[str, Any]] = None  # extra, strategy-specific data

    def __post_init__(self):
        # Give each instance its own empty metadata dict instead of None.
        self.metadata = {} if self.metadata is None else self.metadata
|
||||
|
||||
|
||||
@dataclass
class SelectionResult:
    """Outcome of a load-balancing selection."""

    provider: Any  # the chosen provider
    priority: int  # its priority at selection time
    weight: float  # its weight at selection time
    selection_metadata: Optional[Dict[str, Any]] = None  # details about how it was chosen

    def __post_init__(self):
        # Give each instance its own empty metadata dict instead of None.
        self.selection_metadata = (
            {} if self.selection_metadata is None else self.selection_metadata
        )
|
||||
|
||||
|
||||
class LoadBalancerStrategy(BasePlugin):
    """
    Base class for load-balancing strategies.

    Concrete strategies implement ``select`` (pick one provider from a
    candidate list) and ``get_stats``; ``record_result`` is an optional
    feedback hook for adaptive strategies.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """
        Initialize the load-balancing strategy.

        Args:
            name: strategy name.
            priority: priority (higher number = higher priority).
            version: plugin version.
            author: plugin author.
            description: plugin description.
            api_version: API version.
            dependencies: names of other plugins this one depends on.
            provides: services this plugin provides.
            config: configuration dictionary.
        """
        # Pure pass-through to BasePlugin; strategies add no extra state here.
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )

    @abstractmethod
    async def select(
        self, candidates: List[ProviderCandidate], context: Optional[Dict[str, Any]] = None
    ) -> Optional[SelectionResult]:
        """
        Pick one provider from the candidate list.

        Args:
            candidates: candidate providers.
            context: contextual info (request id, user info, etc.).

        Returns:
            The selection result, or None when no provider is usable.
        """

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """
        Report load-balancing statistics.

        Returns:
            Statistics dictionary.
        """

    async def record_result(
        self,
        provider: Any,
        success: bool,
        response_time: Optional[float] = None,
        error: Optional[Exception] = None,
    ):
        """
        Record the outcome of a request (for adaptive strategies).

        Args:
            provider: provider object.
            success: whether the request succeeded.
            response_time: response time in seconds.
            error: the error, when the request failed.
        """
        # No-op by default; adaptive strategies override this to learn
        # from request outcomes.
        pass
|
||||
450
src/plugins/load_balancer/sticky_priority.py
Normal file
450
src/plugins/load_balancer/sticky_priority.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
粘性优先级负载均衡策略
|
||||
正常情况下始终选择同一个提供商(优先级最高+权重最大),只在故障时切换
|
||||
|
||||
WARNING: 多进程环境注意事项
|
||||
=============================
|
||||
此插件的健康状态和粘性缓存存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
|
||||
每个 worker 进程有独立的状态,可能导致:
|
||||
- 不同 worker 看到的提供商健康状态不同
|
||||
- 粘性路由在不同 worker 间不一致
|
||||
- 统计数据分散在各个 worker 中
|
||||
|
||||
解决方案:
|
||||
1. 单 worker 模式:适用于低流量场景
|
||||
2. Redis 共享状态:将 _provider_health 和 _sticky_providers 迁移到 Redis
|
||||
3. 使用独立的健康检查服务:所有 worker 共享同一个健康状态源
|
||||
|
||||
目前项目已有 Redis 依赖,建议在高可用场景下将状态迁移到 Redis。
|
||||
参考:src/services/health_monitor.py 中的实现
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from .base import LoadBalancerStrategy, ProviderCandidate, SelectionResult
|
||||
|
||||
|
||||
|
||||
class StickyPriorityStrategy(LoadBalancerStrategy):
    """
    Sticky-priority load-balancing strategy.

    Selection logic:
    1. Within the highest-priority group, pick the provider with the largest
       weight as the "sticky" provider.
    2. Under normal conditions, always select that sticky provider.
    3. Only when the sticky provider fails, switch to another provider of the
       same priority.
    4. When the sticky provider recovers, automatically switch back to it.

    Characteristics:
    - Minimizes provider switching; traffic concentrates on a single provider.
    - Automatic failover and recovery.
    - Suited to scenarios that need traffic concentrated on one API key.

    Note:
        State is held in process memory, so with multi-process deployments
        each worker tracks its own independent state. See the module docs.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the strategy.

        Args:
            config: Optional configuration dict. Recognized keys:
                failure_threshold (int), recovery_delay (seconds, int/float),
                enable_auto_recovery (bool).
        """
        config = config or {}  # guard against None so .get() calls below are safe
        super().__init__(
            name="sticky_priority",
            priority=110,  # higher than the default priority_weighted strategy
            version="1.0.0",
            author="System",
            description="粘性优先级负载均衡策略,正常时始终使用同一提供商",
            api_version="1.0",
            provides=["load_balancer"],
            config=config,
        )

        # Configuration parameters
        self.failure_threshold = config.get("failure_threshold", 3)  # consecutive-failure threshold
        self.recovery_delay = config.get("recovery_delay", 30)  # recovery delay in seconds
        self.enable_auto_recovery = config.get("enable_auto_recovery", True)  # auto-recovery switch

        # Per-provider health tracking: {provider_id: health_info}
        self._provider_health: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {
                "consecutive_failures": 0,
                "last_failure_time": None,
                "is_healthy": True,
                "total_requests": 0,
                "total_failures": 0,
            }
        )

        # Current sticky provider cache: {cache_key: provider_id}.
        # cache_key is derived from api_key_id / user_id (see _get_cache_key).
        self._sticky_providers: Dict[str, str] = {}

        # Selection statistics
        self._stats = {
            "total_selections": 0,
            "provider_selections": {},
            "sticky_hits": 0,  # how often the sticky provider was chosen
            "failovers": 0,  # number of failover switches
            "auto_recoveries": 0,  # number of automatic recoveries
        }

    async def select(
        self, candidates: List[ProviderCandidate], context: Optional[Dict[str, Any]] = None
    ) -> Optional[SelectionResult]:
        """
        Pick one provider from the candidate list.

        Args:
            candidates: Candidate providers.
            context: Context info (may carry api_key_id / user_id).

        Returns:
            The selection result, or None when there are no candidates.
        """
        if not candidates:
            logger.warning("No candidates available for selection")
            return None

        # Single candidate: nothing to balance, return it directly.
        if len(candidates) == 1:
            candidate = candidates[0]
            self._record_selection(candidate.provider, is_sticky=True)
            return SelectionResult(
                provider=candidate.provider,
                priority=candidate.priority,
                weight=candidate.weight,
                selection_metadata={"strategy": "single_candidate"},
            )

        # Cache key identifying the request source (api key / user).
        cache_key = self._get_cache_key(context)

        # Group by priority and work only within the highest-priority group.
        priority_groups = self._group_by_priority(candidates)
        highest_priority = max(priority_groups.keys())
        highest_group = priority_groups[highest_priority]

        # Determine the sticky provider for this cache key.
        sticky_candidate = self._determine_sticky_provider(highest_group, cache_key, context)

        # Check the sticky provider's health.
        provider_id = str(sticky_candidate.provider.id)
        health_info = self._provider_health[provider_id]

        # Healthy sticky provider: use it directly.
        if health_info["is_healthy"]:
            self._record_selection(sticky_candidate.provider, is_sticky=True)

            logger.info(f"Selected sticky provider {sticky_candidate.provider.name}")

            return SelectionResult(
                provider=sticky_candidate.provider,
                priority=sticky_candidate.priority,
                weight=sticky_candidate.weight,
                selection_metadata={
                    "strategy": "sticky_priority",
                    "is_sticky": True,
                    "cache_key": cache_key,
                    "health_status": "healthy",
                },
            )

        # Sticky provider unhealthy: fail over to a backup.
        logger.warning(f"Sticky provider {sticky_candidate.provider.name} is unhealthy, selecting backup")

        # Pick a healthy backup from the same priority group.
        backup_candidate = self._select_backup_provider(highest_group)

        if not backup_candidate:
            # No healthy backup: degrade to the unhealthy sticky provider.
            logger.warning("No healthy backup provider available, falling back to unhealthy sticky provider")
            backup_candidate = sticky_candidate

        self._record_selection(backup_candidate.provider, is_sticky=False)
        self._stats["failovers"] += 1

        logger.info(f"Selected backup provider {backup_candidate.provider.name}")

        return SelectionResult(
            provider=backup_candidate.provider,
            priority=backup_candidate.priority,
            weight=backup_candidate.weight,
            selection_metadata={
                "strategy": "sticky_priority",
                "is_sticky": False,
                "is_failover": True,
                "original_provider_id": provider_id,
                "health_status": "backup",
            },
        )

    def _get_cache_key(self, context: Optional[Dict[str, Any]]) -> str:
        """
        Build the cache key that identifies a request source.

        Args:
            context: Context info.

        Returns:
            Cache key string ("default" when no identifying info exists).
        """
        if not context:
            return "default"

        # Prefer api_key_id when present.
        if "api_key_id" in context:
            return f"api_key_{context['api_key_id']}"

        # Fall back to other identifiers.
        if "user_id" in context:
            return f"user_{context['user_id']}"

        return "default"

    def _group_by_priority(
        self, candidates: List[ProviderCandidate]
    ) -> Dict[int, List[ProviderCandidate]]:
        """Group candidate providers by their priority value."""
        groups: Dict[int, List[ProviderCandidate]] = {}
        for candidate in candidates:
            priority = candidate.priority
            if priority not in groups:
                groups[priority] = []
            groups[priority].append(candidate)
        return groups

    def _determine_sticky_provider(
        self,
        candidates: List[ProviderCandidate],
        cache_key: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> ProviderCandidate:
        """
        Determine the sticky provider.

        Strategy:
        1. If a sticky provider is cached, verify it is still a candidate.
        2. Otherwise (or when stale), pick the candidate with the largest
           weight as the new sticky provider.

        Args:
            candidates: Candidates of the same priority group.
            cache_key: Cache key identifying the request source.
            context: Context info (unused; kept for interface compatibility).

        Returns:
            The sticky provider candidate.
        """
        # Check the cached sticky provider first.
        if cache_key in self._sticky_providers:
            cached_provider_id = self._sticky_providers[cache_key]

            # Verify it is still in the candidate list.
            for candidate in candidates:
                if str(candidate.provider.id) == cached_provider_id:
                    # Recover it automatically when the delay has elapsed.
                    if self._can_auto_recover(cached_provider_id):
                        logger.info(f"Auto-recovering sticky provider {candidate.provider.name}")
                        self._stats["auto_recoveries"] += 1
                        # Reset its health state.
                        self._provider_health[cached_provider_id]["is_healthy"] = True
                        self._provider_health[cached_provider_id]["consecutive_failures"] = 0

                    return candidate

        # No (valid) cache entry: pick the heaviest-weighted candidate.
        sticky_candidate = max(candidates, key=lambda c: c.weight)
        self._sticky_providers[cache_key] = str(sticky_candidate.provider.id)

        logger.info(f"Set new sticky provider {sticky_candidate.provider.name}")

        return sticky_candidate

    def _can_auto_recover(self, provider_id: str) -> bool:
        """
        Check whether a provider may be auto-recovered.

        Args:
            provider_id: Provider ID.

        Returns:
            True when the provider is healthy already or the recovery delay
            has elapsed since its last failure (and auto-recovery is enabled).
        """
        if not self.enable_auto_recovery:
            return False

        health_info = self._provider_health[provider_id]

        # Already healthy: trivially recoverable.
        if health_info["is_healthy"]:
            return True

        # Recover once the configured delay has passed since the last failure.
        if health_info["last_failure_time"]:
            time_since_failure = time.time() - health_info["last_failure_time"]
            if time_since_failure >= self.recovery_delay:
                return True

        return False

    def _select_backup_provider(
        self, candidates: List[ProviderCandidate]
    ) -> Optional[ProviderCandidate]:
        """
        Pick a healthy backup provider from the candidate list.

        Prefers the healthy candidate with the largest weight.

        Args:
            candidates: Candidate providers.

        Returns:
            A backup candidate, or None when none are healthy.
        """
        healthy_candidates = []

        for candidate in candidates:
            provider_id = str(candidate.provider.id)
            health_info = self._provider_health[provider_id]

            # Treat recoverable providers as healthy (and recover them now).
            if health_info["is_healthy"] or self._can_auto_recover(provider_id):
                if not health_info["is_healthy"]:
                    # Auto-recover on the spot.
                    health_info["is_healthy"] = True
                    health_info["consecutive_failures"] = 0
                    self._stats["auto_recoveries"] += 1

                healthy_candidates.append(candidate)

        if not healthy_candidates:
            return None

        # Heaviest-weighted healthy provider wins.
        return max(healthy_candidates, key=lambda c: c.weight)

    def _record_selection(self, provider: Any, is_sticky: bool = True):
        """Record selection statistics for a chosen provider."""
        self._stats["total_selections"] += 1
        provider_id = str(provider.id)

        if provider_id not in self._stats["provider_selections"]:
            self._stats["provider_selections"][provider_id] = 0
        self._stats["provider_selections"][provider_id] += 1

        if is_sticky:
            self._stats["sticky_hits"] += 1

    async def record_result(
        self,
        provider: Any,
        success: bool,
        response_time: Optional[float] = None,
        error: Optional[Exception] = None,
    ):
        """
        Record a request outcome and update the provider's health state.

        Args:
            provider: Provider object.
            success: Whether the request succeeded.
            response_time: Response time in seconds (currently unused).
            error: The error, when the request failed (currently unused).
        """
        provider_id = str(provider.id)
        health_info = self._provider_health[provider_id]

        health_info["total_requests"] += 1

        if success:
            # Success resets the consecutive-failure counter.
            health_info["consecutive_failures"] = 0
            health_info["is_healthy"] = True

            logger.debug(f"Recorded successful result for provider {provider.name}")
        else:
            # Failure bumps the counters and timestamps the failure.
            health_info["consecutive_failures"] += 1
            health_info["total_failures"] += 1
            health_info["last_failure_time"] = time.time()

            # Mark unhealthy once the failure threshold is reached.
            if health_info["consecutive_failures"] >= self.failure_threshold:
                health_info["is_healthy"] = False

                logger.warning(f"Provider {provider.name} marked as unhealthy")
            else:
                logger.debug(f"Recorded failed result for provider {provider.name}")

    async def get_stats(self) -> Dict[str, Any]:
        """Return strategy statistics, per-provider health, and config."""
        # Aggregate health counters.
        healthy_count = sum(1 for info in self._provider_health.values() if info["is_healthy"])

        total_providers = len(self._provider_health)

        # Sticky hit rate over all selections.
        sticky_hit_rate = 0.0
        if self._stats["total_selections"] > 0:
            sticky_hit_rate = self._stats["sticky_hits"] / self._stats["total_selections"]

        return {
            "strategy": "sticky_priority",
            "total_selections": self._stats["total_selections"],
            "provider_selections": self._stats["provider_selections"],
            "sticky_hits": self._stats["sticky_hits"],
            "sticky_hit_rate": sticky_hit_rate,
            "failovers": self._stats["failovers"],
            "auto_recoveries": self._stats["auto_recoveries"],
            "healthy_providers": healthy_count,
            "total_providers": total_providers,
            "provider_health": {
                provider_id: {
                    "is_healthy": info["is_healthy"],
                    "consecutive_failures": info["consecutive_failures"],
                    "total_requests": info["total_requests"],
                    "total_failures": info["total_failures"],
                    "failure_rate": (
                        info["total_failures"] / info["total_requests"]
                        if info["total_requests"] > 0
                        else 0
                    ),
                    "last_failure_time": info["last_failure_time"],
                }
                for provider_id, info in self._provider_health.items()
            },
            "sticky_providers": self._sticky_providers,
            "config": {
                "failure_threshold": self.failure_threshold,
                "recovery_delay": self.recovery_delay,
                "enable_auto_recovery": self.enable_auto_recovery,
            },
        }

    async def reset_provider_health(self, provider_id: str):
        """Reset the tracked health state for the given provider."""
        if provider_id in self._provider_health:
            self._provider_health[provider_id] = {
                "consecutive_failures": 0,
                "last_failure_time": None,
                "is_healthy": True,
                "total_requests": 0,
                "total_failures": 0,
            }
            logger.info(f"Reset health status for provider {provider_id}")

    async def clear_sticky_cache(self, cache_key: Optional[str] = None):
        """
        Clear the sticky-provider cache.

        Args:
            cache_key: Specific key to drop; None clears the whole cache.
        """
        if cache_key:
            if cache_key in self._sticky_providers:
                del self._sticky_providers[cache_key]
                logger.info(f"Cleared sticky provider cache for key: {cache_key}")
        else:
            self._sticky_providers.clear()
            logger.info("Cleared all sticky provider cache")
|
||||
579
src/plugins/manager.py
Normal file
579
src/plugins/manager.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""
|
||||
插件管理器
|
||||
统一管理和协调所有插件系统
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.plugins.auth.base import AuthPlugin
|
||||
from src.plugins.cache.base import CachePlugin
|
||||
|
||||
# 移除审计插件 - 审计功能现在是核心服务,不再作为插件
|
||||
from src.plugins.common import BasePlugin, HealthStatus, PluginMetadata
|
||||
from src.plugins.load_balancer.base import LoadBalancerStrategy
|
||||
from src.plugins.monitor.base import MonitorPlugin
|
||||
from src.plugins.notification.base import NotificationPlugin
|
||||
from src.plugins.rate_limit.base import RateLimitStrategy
|
||||
from src.plugins.token.base import TokenCounterPlugin
|
||||
|
||||
|
||||
|
||||
class PluginManager:
    """
    Unified plugin manager.

    Loads, configures, and coordinates all plugin types (auth, rate limit,
    cache, monitor, token, notification, load balancer).
    """

    # API version this manager supports.
    SUPPORTED_API_VERSION = "1.0"

    # Mapping of plugin type name -> expected base class.
    PLUGIN_TYPES = {
        "auth": AuthPlugin,
        "rate_limit": RateLimitStrategy,
        "cache": CachePlugin,
        "monitor": MonitorPlugin,
        "token": TokenCounterPlugin,
        "notification": NotificationPlugin,
        "load_balancer": LoadBalancerStrategy,
        # "audit" removed - auditing is now a core service, not a plugin
    }

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the plugin manager.

        Args:
            config: Configuration dict, keyed by plugin type.
        """
        self.config = config or {}
        # Registered plugins: {plugin_type: {plugin_name: instance}}
        self.plugins: Dict[str, Dict[str, Any]] = {
            "auth": {},
            "rate_limit": {},
            "cache": {},
            "monitor": {},
            "token": {},
            "notification": {},
            "load_balancer": {},
            # "audit" removed - auditing is now a core service
        }
        # Name of the default plugin per type (None = first registered wins).
        self.default_plugins: Dict[str, Optional[str]] = {
            "auth": None,
            "rate_limit": None,
            "cache": None,
            "monitor": None,
            "token": None,
            "notification": None,
            "load_balancer": "sticky_priority",  # sticky-priority strategy by default
            # "audit" removed - auditing is now a core service
        }
        # Plugins skipped/disabled because of API-version incompatibility.
        self._incompatible_plugins: List[str] = []

        # Discover and load plugins from the package directories.
        self._auto_discover_plugins()

        # Apply the supplied configuration.
        self._apply_config()

    def _auto_discover_plugins(self):
        """Scan the plugin package directories and load every plugin module."""
        plugins_dir = Path(__file__).parent

        for plugin_type in self.PLUGIN_TYPES:
            type_dir = plugins_dir / plugin_type
            if not type_dir.exists():
                continue

            # Scan the type's directory for plugin modules.
            for file_path in type_dir.glob("*.py"):
                # Skip private modules and the base-class module.
                if file_path.name.startswith("_") or file_path.name == "base.py":
                    continue

                module_name = f"src.plugins.{plugin_type}.{file_path.stem}"
                try:
                    module = importlib.import_module(module_name)
                    self._load_plugin_from_module(module, plugin_type)
                except Exception as e:
                    logger.error(f"Failed to load plugin module {module_name}: {e}")

    def _is_api_version_compatible(self, plugin_api_version: str) -> bool:
        """
        Check whether a plugin's API version is compatible.

        Uses semantic-version major compatibility:
        - Same major version => compatible.
        - e.g. supported "1.0" accepts plugin versions "1.0", "1.1", "1.2".

        Args:
            plugin_api_version: The API version declared by the plugin.

        Returns:
            True when compatible (or unparsable, which is assumed compatible).
        """
        try:
            supported_major = self.SUPPORTED_API_VERSION.split(".")[0]
            plugin_major = plugin_api_version.split(".")[0]
            return supported_major == plugin_major
        except (ValueError, IndexError):
            # Parsing failed; assume compatible.
            return True

    def _load_plugin_from_module(self, module: Any, plugin_type: str):
        """Instantiate and register every plugin class found in a module."""
        base_class = self.PLUGIN_TYPES[plugin_type]

        for name, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, base_class) and obj != base_class:
                # Instantiate the plugin class.
                try:
                    plugin_instance = obj()

                    # Disable (but still register) plugins with an
                    # incompatible API version.
                    plugin_api_version = getattr(plugin_instance.metadata, "api_version", "1.0")
                    if not self._is_api_version_compatible(plugin_api_version):
                        logger.warning(f"Plugin {plugin_instance.name} has incompatible API version "
                                       f"{plugin_api_version} (supported: {self.SUPPORTED_API_VERSION}), "
                                       f"plugin will be disabled")
                        plugin_instance.enabled = False
                        self._incompatible_plugins.append(plugin_instance.name)

                    self.register_plugin(plugin_type, plugin_instance)
                    logger.info(f"Loaded {plugin_type} plugin: {plugin_instance.name}")
                except Exception as e:
                    logger.error(f"Failed to instantiate plugin {name}: {e}")

    def _apply_config(self):
        """Apply the manager's configuration to every registered plugin."""
        for plugin_type, plugins in self.plugins.items():
            type_config = self.config.get(plugin_type, {})

            # Honor a configured default plugin for this type.
            if "default" in type_config:
                self.default_plugins[plugin_type] = type_config["default"]

            # Configure each individual plugin.
            for plugin_name, plugin in plugins.items():
                plugin_config = type_config.get(plugin_name, {})
                if plugin_config:
                    plugin.configure(plugin_config)

    def register_plugin(self, plugin_type: str, plugin: Any, set_as_default: bool = False):
        """
        Register a plugin instance.

        Args:
            plugin_type: Plugin type key (must exist in PLUGIN_TYPES).
            plugin: Plugin instance.
            set_as_default: Make this plugin the type's default.

        Raises:
            ValueError: Unknown plugin type.
            TypeError: Instance is not of the type's expected base class.
        """
        if plugin_type not in self.plugins:
            raise ValueError(f"Unknown plugin type: {plugin_type}")

        # Validate the instance against the expected base class.
        base_class = self.PLUGIN_TYPES[plugin_type]
        if not isinstance(plugin, base_class):
            raise TypeError(
                f"Plugin must be instance of {base_class.__name__}, " f"got {type(plugin).__name__}"
            )

        # Register the instance.
        self.plugins[plugin_type][plugin.name] = plugin

        # Promote to default when requested or when no default exists yet.
        if set_as_default or not self.default_plugins[plugin_type]:
            self.default_plugins[plugin_type] = plugin.name

        logger.debug(f"Registered {plugin_type} plugin: {plugin.name}")

    def unregister_plugin(self, plugin_type: str, plugin_name: str):
        """
        Unregister a plugin.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name.
        """
        if plugin_type in self.plugins:
            if plugin_name in self.plugins[plugin_type]:
                del self.plugins[plugin_type][plugin_name]

                # Clear the default slot when the default plugin is removed.
                if self.default_plugins[plugin_type] == plugin_name:
                    self.default_plugins[plugin_type] = None

                logger.debug(f"Unregistered {plugin_type} plugin: {plugin_name}")

    def get_plugin(self, plugin_type: str, plugin_name: Optional[str] = None) -> Optional[Any]:
        """
        Look up a plugin instance.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name; when omitted, the type's default is used.

        Returns:
            The plugin instance, or None when not found.
        """
        if plugin_type not in self.plugins:
            return None

        if plugin_name:
            return self.plugins[plugin_type].get(plugin_name)

        # Fall back to the configured default plugin.
        default_name = self.default_plugins[plugin_type]
        if default_name:
            return self.plugins[plugin_type].get(default_name)

        # Finally fall back to the first available plugin of the type.
        if self.plugins[plugin_type]:
            return next(iter(self.plugins[plugin_type].values()))

        return None

    def get_plugins_by_type(self, plugin_type: str) -> List[Any]:
        """
        Return all plugins of a given type.

        Args:
            plugin_type: Plugin type key.

        Returns:
            List of plugin instances (empty for unknown types).
        """
        if plugin_type not in self.plugins:
            return []

        return list(self.plugins[plugin_type].values())

    def get_enabled_plugins(self, plugin_type: str) -> List[Any]:
        """
        Return all enabled plugins of a given type.

        Args:
            plugin_type: Plugin type key.

        Returns:
            List of enabled plugin instances.
        """
        plugins = self.get_plugins_by_type(plugin_type)
        return [p for p in plugins if getattr(p, "enabled", True)]

    async def execute_plugin_chain(
        self, plugin_type: str, method_name: str, *args, **kwargs
    ) -> Any:
        """
        Execute a method across the plugin chain, in priority order.

        Args:
            plugin_type: Plugin type key.
            method_name: Method to invoke on each plugin.
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            The first non-None result, or None when no plugin produced one.
        """
        plugins = self.get_enabled_plugins(plugin_type)

        # Higher priority runs first.
        plugins.sort(key=lambda p: getattr(p, "priority", 0), reverse=True)

        for plugin in plugins:
            if hasattr(plugin, method_name):
                method = getattr(plugin, method_name)
                try:
                    if asyncio.iscoroutinefunction(method):
                        result = await method(*args, **kwargs)
                    else:
                        result = method(*args, **kwargs)

                    if result is not None:
                        return result
                except Exception as e:
                    logger.error(f"Plugin {plugin.name} failed in {method_name}: {e}")

        return None

    def get_stats(self) -> Dict[str, Any]:
        """
        Return plugin-manager statistics.

        Returns:
            Stats dict with counts, defaults, per-plugin details, and
            incompatible-plugin names.
        """
        stats = {
            "supported_api_version": self.SUPPORTED_API_VERSION,
            "plugin_counts": {},
            "enabled_counts": {},
            "default_plugins": self.default_plugins,
            "plugin_details": {},
            "incompatible_plugins": self._incompatible_plugins,
        }

        for plugin_type in self.PLUGIN_TYPES:
            all_plugins = self.get_plugins_by_type(plugin_type)
            enabled_plugins = self.get_enabled_plugins(plugin_type)

            stats["plugin_counts"][plugin_type] = len(all_plugins)
            stats["enabled_counts"][plugin_type] = len(enabled_plugins)

            # Per-plugin detail entries.
            stats["plugin_details"][plugin_type] = [
                {
                    "name": p.name,
                    "enabled": getattr(p, "enabled", True),
                    "priority": getattr(p, "priority", 0),
                    "class": type(p).__name__,
                    "api_version": getattr(p.metadata, "api_version", "unknown"),
                    "version": getattr(p.metadata, "version", "unknown"),
                }
                for p in all_plugins
            ]

        return stats

    async def initialize_all(self) -> Dict[str, bool]:
        """
        Initialize every plugin.

        Plugins that fail to initialize are disabled so they cannot be used
        in an inconsistent state.

        Returns:
            Result dict {plugin_name: success}.
        """
        results = {}

        # Collect all plugins and sort them by dependency order.
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))

        # Topological sort so dependencies initialize first.
        sorted_plugins = self._sort_plugins_by_dependencies(all_plugins)

        # Initialize in order.
        for plugin in sorted_plugins:
            try:
                # Plugins without initialize() are assumed ready.
                if not hasattr(plugin, "initialize"):
                    logger.debug(f"Plugin {plugin.name} has no initialize() method, skipping")
                    results[f"{plugin.name}"] = True
                    continue

                success = await plugin.initialize()
                results[f"{plugin.name}"] = success
                if success:
                    logger.info(f"Successfully initialized plugin: {plugin.name}")
                else:
                    # Initialization failed: disable the plugin.
                    plugin.enabled = False
                    logger.error(f"Failed to initialize plugin: {plugin.name}, plugin has been disabled")
            except Exception as e:
                results[f"{plugin.name}"] = False
                # Initialization raised: disable the plugin.
                plugin.enabled = False
                logger.error(f"Error initializing plugin {plugin.name}: {e}, plugin has been disabled")

        return results

    async def shutdown_all(self):
        """
        Shut down every plugin.
        """
        # Collect plugins and shut down in reverse dependency order
        # (dependents before their dependencies).
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))

        sorted_plugins = self._sort_plugins_by_dependencies(all_plugins)
        sorted_plugins.reverse()  # shut down dependents first

        # Shut plugins down concurrently.
        shutdown_tasks = []
        for plugin in sorted_plugins:
            # Only plugins that define shutdown() take part.
            if hasattr(plugin, "shutdown"):
                shutdown_tasks.append(plugin.shutdown())

        if shutdown_tasks:
            try:
                await asyncio.gather(*shutdown_tasks, return_exceptions=True)
                logger.info("All plugins shut down")
            except Exception as e:
                logger.error(f"Error during plugin shutdown: {e}")

    def _sort_plugins_by_dependencies(self, plugins: List[BasePlugin]) -> List[BasePlugin]:
        """
        Topologically sort plugins by their declared dependencies.

        Args:
            plugins: Plugins to sort.

        Returns:
            Plugins ordered so that dependencies precede their dependents.

        Note:
            Plugins involved in a dependency cycle are disabled and excluded
            from the result.
        """
        # Map plugin name -> plugin instance for quick lookup.
        plugin_map = {plugin.name: plugin for plugin in plugins}

        # in_degree[name] = number of this plugin's dependencies that are
        # present in the given plugin set (Kahn's algorithm in-degree).
        in_degree = {plugin.name: 0 for plugin in plugins}

        # Build the dependency graph.
        for plugin in plugins:
            for dep in plugin.metadata.dependencies:
                if dep in in_degree:
                    in_degree[plugin.name] += 1

        # Topological sort: start from plugins with no unresolved dependencies.
        queue = [name for name, degree in in_degree.items() if degree == 0]
        result = []

        while queue:
            current = queue.pop(0)
            result.append(plugin_map[current])

            # Resolve the current plugin for every plugin depending on it.
            for plugin in plugins:
                if current in plugin.metadata.dependencies:
                    in_degree[plugin.name] -= 1
                    if in_degree[plugin.name] == 0:
                        queue.append(plugin.name)

        # Anything left unresolved is part of a dependency cycle.
        if len(result) != len(plugins):
            remaining = [p for p in plugins if p not in result]
            circular_names = [p.name for p in remaining]
            logger.error(f"Circular dependency detected among plugins: {circular_names}. "
                         f"These plugins will be disabled.")
            # Disable cyclic plugins instead of loading them.
            for plugin in remaining:
                plugin.enabled = False
                logger.warning(f"Plugin {plugin.name} has been disabled due to circular dependency")
            # Cyclic plugins are intentionally left out of the result.

        return result

    async def health_check_all(self) -> Dict[str, HealthStatus]:
        """
        Check the health of every plugin.

        Returns:
            Health dict {plugin_name: status}; an exception from a plugin's
            check maps to HealthStatus.UNHEALTHY.
        """
        results = {}

        # Collect every plugin.
        all_plugins = []
        for plugin_type in self.PLUGIN_TYPES:
            all_plugins.extend(self.get_plugins_by_type(plugin_type))

        # Run the checks concurrently.
        health_tasks = []
        for plugin in all_plugins:
            health_tasks.append(plugin.health_check())

        if health_tasks:
            health_results = await asyncio.gather(*health_tasks, return_exceptions=True)

            for plugin, result in zip(all_plugins, health_results):
                if isinstance(result, Exception):
                    results[plugin.name] = HealthStatus.UNHEALTHY
                else:
                    results[plugin.name] = result

        return results

    def validate_plugin_dependencies(self) -> Dict[str, List[str]]:
        """
        Validate the declared dependencies of every plugin.

        Returns:
            Dict {plugin_name: [missing_dependencies]} for plugins with
            unmet dependencies.
        """
        results = {}

        # Collect all available plugin names per type.
        available_plugins = {}
        for plugin_type in self.PLUGIN_TYPES:
            available_plugins[plugin_type] = [p.name for p in self.get_plugins_by_type(plugin_type)]

        # Ask each plugin to validate its own dependencies.
        for plugin_type in self.PLUGIN_TYPES:
            for plugin in self.get_plugins_by_type(plugin_type):
                missing_deps = plugin.validate_dependencies(available_plugins)
                if missing_deps:
                    results[plugin.name] = missing_deps

        return results

    def reload_plugin_config(
        self, plugin_type: str, plugin_name: str, new_config: Dict[str, Any]
    ) -> bool:
        """
        Reload a plugin's configuration.

        Args:
            plugin_type: Plugin type key.
            plugin_name: Plugin name.
            new_config: New configuration dict.

        Returns:
            True when the configuration was applied successfully.
        """
        plugin = self.get_plugin(plugin_type, plugin_name)
        if not plugin:
            return False

        try:
            plugin.configure(new_config)
            logger.info(f"Reloaded config for plugin {plugin_name}: {new_config}")
            return True
        except Exception as e:
            logger.error(f"Failed to reload config for plugin {plugin_name}: {e}")
            return False
|
||||
|
||||
|
||||
# Global plugin-manager singleton; the lock makes lazy creation thread-safe.
_plugin_manager: Optional[PluginManager] = None
_plugin_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_plugin_manager(config: Optional[Dict[str, Any]] = None) -> PluginManager:
    """
    Return the process-wide plugin manager, creating it on first use (thread-safe).

    Args:
        config: Configuration dict; it is only honoured by the call that
            actually constructs the manager — later calls ignore it.

    Returns:
        The shared PluginManager instance.
    """
    global _plugin_manager

    # Fast path: already created, no locking needed.
    if _plugin_manager is not None:
        return _plugin_manager

    with _plugin_manager_lock:
        # Double-checked locking: another thread may have created it while
        # we waited for the lock.
        if _plugin_manager is None:
            _plugin_manager = PluginManager(config)

    return _plugin_manager
|
||||
|
||||
|
||||
def reset_plugin_manager():
    """Drop the global plugin-manager instance (intended for tests)."""
    global _plugin_manager
    with _plugin_manager_lock:
        _plugin_manager = None
|
||||
5
src/plugins/monitor/__init__.py
Normal file
5
src/plugins/monitor/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""监控插件包"""
|
||||
|
||||
from .base import Metric, MetricType, MonitorPlugin
|
||||
|
||||
__all__ = ["MonitorPlugin", "Metric", "MetricType"]
|
||||
250
src/plugins/monitor/base.py
Normal file
250
src/plugins/monitor/base.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
监控插件基类
|
||||
定义监控和指标收集的接口
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.plugins.common import BasePlugin
|
||||
|
||||
|
||||
class MetricType(Enum):
    """Kinds of metrics supported by monitor plugins."""

    COUNTER = "counter"  # monotonically increasing counter
    GAUGE = "gauge"  # value that can go up or down
    HISTOGRAM = "histogram"  # distribution of observations
    SUMMARY = "summary"  # quantile summary
|
||||
|
||||
|
||||
class Metric:
    """A single metric observation: name, value, type, labels, and timestamp."""

    def __init__(
        self,
        name: str,
        value: float,
        metric_type: MetricType,
        labels: Optional[Dict[str, str]] = None,
        timestamp: Optional[datetime] = None,
        description: Optional[str] = None,
    ):
        # Fall back to an empty label set / current UTC time when the caller
        # supplied nothing (falsy values are replaced, matching `or`).
        resolved_labels = labels or {}
        resolved_timestamp = timestamp or datetime.now(timezone.utc)

        self.name = name
        self.value = value
        self.metric_type = metric_type
        self.labels = resolved_labels
        self.timestamp = resolved_timestamp
        self.description = description
|
||||
|
||||
|
||||
class MonitorPlugin(BasePlugin):
|
||||
"""
|
||||
监控插件基类
|
||||
所有监控插件必须继承此类并实现相关方法
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: Dict[str, Any] = None):
|
||||
"""
|
||||
初始化监控插件
|
||||
|
||||
Args:
|
||||
name: 插件名称
|
||||
config: 配置字典
|
||||
"""
|
||||
# 调用父类初始化,设置metadata
|
||||
super().__init__(name=name, config=config, description="Monitor Plugin", version="1.0.0")
|
||||
|
||||
self.flush_interval = self.config.get("flush_interval", 60)
|
||||
self.batch_size = self.config.get("batch_size", 100)
|
||||
|
||||
@abstractmethod
|
||||
async def record_metric(self, metric: Metric):
|
||||
"""
|
||||
记录单个指标
|
||||
|
||||
Args:
|
||||
metric: 指标数据
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def record_batch(self, metrics: List[Metric]):
|
||||
"""
|
||||
批量记录指标
|
||||
|
||||
Args:
|
||||
metrics: 指标列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def increment(self, name: str, value: float = 1, labels: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
增加计数器
|
||||
|
||||
Args:
|
||||
name: 指标名称
|
||||
value: 增加的值
|
||||
labels: 标签字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
设置仪表值
|
||||
|
||||
Args:
|
||||
name: 指标名称
|
||||
value: 仪表值
|
||||
labels: 标签字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def histogram(
|
||||
self,
|
||||
name: str,
|
||||
value: float,
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
buckets: Optional[List[float]] = None,
|
||||
):
|
||||
"""
|
||||
记录直方图数据
|
||||
|
||||
Args:
|
||||
name: 指标名称
|
||||
value: 观测值
|
||||
labels: 标签字典
|
||||
buckets: 桶边界
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def timing(self, name: str, duration: float, labels: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
记录时间指标
|
||||
|
||||
Args:
|
||||
name: 指标名称
|
||||
duration: 持续时间(秒)
|
||||
labels: 标签字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self):
|
||||
"""
|
||||
刷新缓冲的指标到后端
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
pass
|
||||
|
||||
def record_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
status_code: int,
|
||||
duration: float,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
记录API请求指标(便捷方法)
|
||||
|
||||
Args:
|
||||
method: HTTP方法
|
||||
endpoint: 端点路径
|
||||
status_code: 状态码
|
||||
duration: 请求时长
|
||||
provider: 提供商名称
|
||||
model: 模型名称
|
||||
"""
|
||||
labels = {
|
||||
"method": method,
|
||||
"endpoint": endpoint,
|
||||
"status": str(status_code),
|
||||
"status_class": f"{status_code // 100}xx",
|
||||
}
|
||||
|
||||
if provider:
|
||||
labels["provider"] = provider
|
||||
if model:
|
||||
labels["model"] = model
|
||||
|
||||
# 异步记录指标
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# 请求计数
|
||||
loop.create_task(self.increment("http_requests_total", labels=labels))
|
||||
|
||||
# 请求延迟
|
||||
loop.create_task(self.histogram("http_request_duration_seconds", duration, labels=labels))
|
||||
|
||||
# 错误计数
|
||||
if status_code >= 400:
|
||||
loop.create_task(self.increment("http_errors_total", labels=labels))
|
||||
|
||||
def record_token_usage(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cost: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
记录Token使用指标(便捷方法)
|
||||
|
||||
Args:
|
||||
provider: 提供商名称
|
||||
model: 模型名称
|
||||
input_tokens: 输入token数
|
||||
output_tokens: 输出token数
|
||||
cost: 费用
|
||||
"""
|
||||
labels = {"provider": provider, "model": model}
|
||||
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Token计数
|
||||
loop.create_task(self.increment("tokens_input_total", input_tokens, labels=labels))
|
||||
loop.create_task(self.increment("tokens_output_total", output_tokens, labels=labels))
|
||||
loop.create_task(
|
||||
self.increment("tokens_total", input_tokens + output_tokens, labels=labels)
|
||||
)
|
||||
|
||||
# 费用
|
||||
if cost is not None:
|
||||
loop.create_task(self.increment("usage_cost_total", cost, labels=labels))
|
||||
|
||||
def configure(self, config: Dict[str, Any]):
|
||||
"""
|
||||
配置插件
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
"""
|
||||
self.config.update(config)
|
||||
self.enabled = config.get("enabled", True)
|
||||
self.flush_interval = config.get("flush_interval", self.flush_interval)
|
||||
self.batch_size = config.get("batch_size", self.batch_size)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(name={self.name}, enabled={self.enabled})>"
|
||||
320
src/plugins/monitor/prometheus.py
Normal file
320
src/plugins/monitor/prometheus.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Prometheus监控插件
|
||||
支持将指标导出到Prometheus
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from prometheus_client import REGISTRY, Counter, Gauge, Histogram, Summary, generate_latest
|
||||
|
||||
PROMETHEUS_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Prometheus client not installed, plugin will be disabled
|
||||
PROMETHEUS_AVAILABLE = False
|
||||
Counter = Gauge = Histogram = Summary = REGISTRY = generate_latest = None
|
||||
|
||||
from .base import Metric, MetricType, MonitorPlugin
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class PrometheusPlugin(MonitorPlugin):
    """
    Prometheus monitoring plugin.

    Exports metrics through the ``prometheus_client`` library. When that
    library is not installed the plugin disables itself and every operation
    degrades to a safe no-op instead of raising.
    """

    def __init__(self, name: str = "prometheus", config: Dict[str, Any] = None):
        super().__init__(name, config)

        # Fix: internal state is created unconditionally so flush(),
        # shutdown() and get_stats() stay safe even when the plugin is
        # disabled (previously they raised AttributeError in that case).
        self._metrics: Dict[str, Any] = {}  # name -> prometheus collector
        self._buffer: List[Metric] = []  # metrics awaiting flush
        self._lock = asyncio.Lock()  # guards self._buffer
        self._flush_task: Optional[asyncio.Task] = None  # background flusher handle

        # Check if prometheus_client is available.
        if not PROMETHEUS_AVAILABLE:
            self.enabled = False
            logger.warning("Prometheus client not installed, plugin disabled")
            return

        # Pre-register the commonly used metrics.
        self._init_default_metrics()

        # Start the periodic flush task (requires a running event loop).
        self._start_flush_task()

    def _init_default_metrics(self):
        """Register the default metric collectors."""
        # HTTP request metrics
        http_label_names = ["method", "endpoint", "status", "status_class"]

        self._metrics["http_requests_total"] = Counter(
            "http_requests_total",
            "Total HTTP requests",
            http_label_names,
        )

        self._metrics["http_request_duration_seconds"] = Histogram(
            "http_request_duration_seconds",
            "HTTP request duration in seconds",
            http_label_names,
            buckets=(0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10),
        )

        self._metrics["http_errors_total"] = Counter(
            "http_errors_total",
            "Total HTTP errors",
            http_label_names,
        )

        # Token usage metrics
        self._metrics["tokens_input_total"] = Counter(
            "tokens_input_total", "Total input tokens", ["provider", "model"]
        )

        self._metrics["tokens_output_total"] = Counter(
            "tokens_output_total", "Total output tokens", ["provider", "model"]
        )

        self._metrics["tokens_total"] = Counter(
            "tokens_total", "Total tokens", ["provider", "model"]
        )

        self._metrics["usage_cost_total"] = Counter(
            "usage_cost_total", "Total usage cost in USD", ["provider", "model"]
        )

        # System metrics
        self._metrics["active_connections"] = Gauge(
            "active_connections", "Number of active connections"
        )

        self._metrics["cache_hits_total"] = Counter(
            "cache_hits_total", "Total cache hits", ["cache_type"]
        )

        self._metrics["cache_misses_total"] = Counter(
            "cache_misses_total", "Total cache misses", ["cache_type"]
        )

        # Provider health metrics
        self._metrics["provider_health"] = Gauge(
            "provider_health", "Provider health status (1=healthy, 0=unhealthy)", ["provider"]
        )

        self._metrics["provider_latency_seconds"] = Histogram(
            "provider_latency_seconds",
            "Provider response latency in seconds",
            ["provider", "model"],
            buckets=(0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30),
        )

    def _start_flush_task(self):
        """Start the periodic background flush task."""

        async def flush_loop():
            try:
                while self.enabled:
                    await asyncio.sleep(self.flush_interval)
                    if self.enabled:  # re-check to avoid flushing during shutdown
                        await self.flush()
            except asyncio.CancelledError:
                # Task cancelled: normal shutdown path.
                logger.debug("Prometheus flush task cancelled")
            except Exception as e:
                logger.error(f"Prometheus flush loop error: {e}")

        # Keep the task handle so shutdown() can cancel it.
        # Fix: use get_running_loop() — get_event_loop() is deprecated
        # outside a running loop and could create a loop that never runs.
        try:
            loop = asyncio.get_running_loop()
            self._flush_task = loop.create_task(flush_loop())
        except RuntimeError:
            # No running event loop; the task cannot be scheduled here.
            logger.warning("No event loop available for Prometheus flush task")

    def _get_or_create_metric(self, name: str, metric_type: MetricType, labels: List[str] = None):
        """Return the collector for ``name``, creating it on first use."""
        if name not in self._metrics:
            labels = labels or []
            if metric_type == MetricType.COUNTER:
                self._metrics[name] = Counter(name, f"Auto-created counter {name}", labels)
            elif metric_type == MetricType.GAUGE:
                self._metrics[name] = Gauge(name, f"Auto-created gauge {name}", labels)
            elif metric_type == MetricType.HISTOGRAM:
                self._metrics[name] = Histogram(name, f"Auto-created histogram {name}", labels)
            elif metric_type == MetricType.SUMMARY:
                self._metrics[name] = Summary(name, f"Auto-created summary {name}", labels)

        return self._metrics[name]

    async def record_metric(self, metric: Metric):
        """Buffer one metric; drain the buffer once it reaches batch_size."""
        async with self._lock:
            self._buffer.append(metric)
            need_flush = len(self._buffer) >= self.batch_size

        # Fix: flush outside the lock. flush() acquires self._lock itself
        # and asyncio.Lock is not reentrant — the original awaited flush()
        # while still holding the lock, which deadlocks.
        if need_flush:
            await self.flush()

    async def record_batch(self, metrics: List[Metric]):
        """Buffer several metrics; drain the buffer once it reaches batch_size."""
        async with self._lock:
            self._buffer.extend(metrics)
            need_flush = len(self._buffer) >= self.batch_size

        # Fix: flush outside the lock (see record_metric).
        if need_flush:
            await self.flush()

    async def increment(self, name: str, value: float = 1, labels: Optional[Dict[str, str]] = None):
        """Increase a counter, creating it on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    # Drop label names the collector was not registered with.
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).inc(value)
                else:
                    metric.inc(value)
            else:
                # Create a new counter keyed by the provided label names.
                label_names = list(labels.keys()) if labels else []
                metric = self._get_or_create_metric(name, MetricType.COUNTER, label_names)
                if labels:
                    metric.labels(**labels).inc(value)
                else:
                    metric.inc(value)
        except Exception as e:
            # Metrics are best-effort: log and carry on.
            logger.warning(f"Error recording metric {name}: {e}")

    async def gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """Set a gauge value, creating the gauge on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).set(value)
                else:
                    metric.set(value)
            else:
                label_names = list(labels.keys()) if labels else []
                metric = self._get_or_create_metric(name, MetricType.GAUGE, label_names)
                if labels:
                    metric.labels(**labels).set(value)
                else:
                    metric.set(value)
        except Exception as e:
            logger.warning(f"Error recording gauge {name}: {e}")

    async def histogram(
        self,
        name: str,
        value: float,
        labels: Optional[Dict[str, str]] = None,
        buckets: Optional[List[float]] = None,
    ):
        """Record a histogram observation, creating the histogram on first use."""
        try:
            if name in self._metrics:
                metric = self._metrics[name]
                if labels:
                    filtered_labels = {k: v for k, v in labels.items() if k in metric._labelnames}
                    metric.labels(**filtered_labels).observe(value)
                else:
                    metric.observe(value)
            else:
                label_names = list(labels.keys()) if labels else []
                if buckets:
                    # Custom buckets require constructing the histogram directly.
                    metric = Histogram(
                        name, f"Auto-created histogram {name}", label_names, buckets=buckets
                    )
                else:
                    metric = self._get_or_create_metric(name, MetricType.HISTOGRAM, label_names)
                self._metrics[name] = metric

                if labels:
                    metric.labels(**labels).observe(value)
                else:
                    metric.observe(value)
        except Exception as e:
            logger.warning(f"Error recording histogram {name}: {e}")

    async def timing(self, name: str, duration: float, labels: Optional[Dict[str, str]] = None):
        """Record a duration (seconds) as a histogram observation."""
        await self.histogram(f"{name}_seconds", duration, labels)

    async def flush(self):
        """Replay buffered metrics into their Prometheus collectors."""
        async with self._lock:
            if not self._buffer:
                return

            # Dispatch each buffered metric by its type.
            for metric in self._buffer:
                if metric.metric_type == MetricType.COUNTER:
                    await self.increment(metric.name, metric.value, metric.labels)
                elif metric.metric_type == MetricType.GAUGE:
                    await self.gauge(metric.name, metric.value, metric.labels)
                elif metric.metric_type == MetricType.HISTOGRAM:
                    await self.histogram(metric.name, metric.value, metric.labels)

            # Drop everything that was replayed.
            self._buffer.clear()

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics."""
        return {
            "type": "prometheus",
            "metrics_count": len(self._metrics),
            "buffer_size": len(self._buffer),
            "flush_interval": self.flush_interval,
            "batch_size": self.batch_size,
        }

    def get_metrics(self) -> bytes:
        """
        Return the current metrics in the Prometheus text exposition format.

        Returns:
            Prometheus-formatted metrics, or ``b""`` when prometheus_client
            is not installed (previously this crashed because
            ``generate_latest`` is None in that case).
        """
        if not PROMETHEUS_AVAILABLE:
            return b""
        return generate_latest(REGISTRY)

    async def shutdown(self):
        """
        Shut the plugin down and cancel the background task.

        Should be called when the application is shutting down.
        """
        # Disable the plugin so the flush loop exits.
        self.enabled = False

        # Cancel and await the background task, if any.
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # One final flush of anything still buffered.
        await self.flush()

        logger.info("Prometheus plugin shutdown complete")

    async def cleanup(self):
        """
        Release resources (alias for :meth:`shutdown`).
        """
        await self.shutdown()
|
||||
15
src/plugins/notification/__init__.py
Normal file
15
src/plugins/notification/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
通知插件
|
||||
"""
|
||||
|
||||
from .base import Notification, NotificationLevel, NotificationPlugin
|
||||
from .email import EmailNotificationPlugin
|
||||
from .webhook import WebhookNotificationPlugin
|
||||
|
||||
__all__ = [
|
||||
"NotificationPlugin",
|
||||
"NotificationLevel",
|
||||
"Notification",
|
||||
"WebhookNotificationPlugin",
|
||||
"EmailNotificationPlugin",
|
||||
]
|
||||
414
src/plugins/notification/base.py
Normal file
414
src/plugins/notification/base.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
通知插件基类
|
||||
定义通知的接口和数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.plugins.common import BasePlugin
|
||||
|
||||
|
||||
class NotificationLevel(Enum):
    """Severity of a notification, ordered from least to most severe."""

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
|
||||
|
||||
|
||||
class Notification:
    """A single notification message together with its routing metadata."""

    def __init__(
        self,
        title: str,
        message: str,
        level: NotificationLevel = NotificationLevel.INFO,
        notification_type: Optional[str] = None,
        source: Optional[str] = None,
        timestamp: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ):
        """
        Args:
            title: Short headline of the notification.
            message: Full body text.
            level: Severity; defaults to INFO.
            notification_type: Category string; defaults to "system".
            source: Originating component; defaults to "aether".
            timestamp: Creation time; defaults to now in UTC.
            metadata: Arbitrary extra key/value context.
            recipient: Optional target address or identifier.
            tags: Optional classification tags.
        """
        self.title = title
        self.message = message
        self.level = level
        self.notification_type = notification_type or "system"
        self.source = source or "aether"
        self.timestamp = timestamp or datetime.now(timezone.utc)
        self.metadata = metadata or {}
        self.recipient = recipient
        self.tags = tags or []

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict (level and timestamp as strings)."""
        return {
            "title": self.title,
            "message": self.message,
            "level": self.level.value,
            "type": self.notification_type,
            "source": self.source,
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata,
            "recipient": self.recipient,
            "tags": self.tags,
        }

    def to_json(self) -> str:
        """Serialise to JSON (non-JSON-native values fall back to str())."""
        return json.dumps(self.to_dict(), default=str)

    def format_message(self, template: Optional[str] = None) -> str:
        """
        Render the notification as text.

        Args:
            template: Optional ``str.format`` template. It may reference
                ``title``, ``message``, ``level``, ``type``, ``source``,
                ``timestamp`` and any metadata key.

        Returns:
            The rendered string; without a template, the default
            "[LEVEL] title\\nmessage" form.
        """
        if template:
            # Fix: merge metadata first so the reserved keys below always
            # win. Previously metadata containing e.g. a "title" key made
            # str.format raise TypeError ("got multiple values for keyword
            # argument 'title'").
            context: Dict[str, Any] = dict(self.metadata)
            context.update(
                title=self.title,
                message=self.message,
                level=self.level.value,
                type=self.notification_type,
                source=self.source,
                timestamp=self.timestamp.isoformat(),
            )
            return template.format(**context)
        else:
            # Default format
            return f"[{self.level.value.upper()}] {self.title}\n{self.message}"
|
||||
|
||||
|
||||
class NotificationPlugin(BasePlugin):
    """
    Base class for notification plugins.

    Every notification plugin must implement this interface.

    Provides a unified retry mechanism (bounded attempts with exponential
    backoff); subclasses only implement ``_do_send`` and ``_do_send_batch``.
    """

    def __init__(self, name: str = "notification", config: Dict[str, Any] = None):
        # Initialise BasePlugin metadata (name/config/description/version).
        super().__init__(
            name=name, config=config, description="Notification Plugin", version="1.0.0"
        )

        # Minimum severity that will actually be delivered.
        self.min_level = NotificationLevel[self.config.get("min_level", "INFO").upper()]
        self.batch_size = self.config.get("batch_size", 10)
        self.flush_interval = self.config.get("flush_interval", 60)  # seconds
        self.retry_count = self.config.get("retry_count", 3)
        self.retry_delay = self.config.get("retry_delay", 5)  # seconds
        self.retry_backoff = self.config.get("retry_backoff", 2.0)  # exponential backoff factor

        # Delivery statistics.
        self._send_attempts = 0
        self._send_successes = 0
        self._send_failures = 0
        self._retry_total = 0

    async def send(self, notification: Notification) -> bool:
        """
        Send a single notification with retries.

        Retries up to ``retry_count`` times, sleeping
        ``retry_delay * retry_backoff**attempt`` between attempts.

        Args:
            notification: The notification to deliver.

        Returns:
            True if the notification was delivered, False otherwise.
        """
        if not self.should_send(notification):
            return False

        self._send_attempts += 1
        last_error = None

        for attempt in range(self.retry_count):
            try:
                result = await self._do_send(notification)
                if result:
                    self._send_successes += 1
                    return True
            except Exception as e:
                last_error = e

            # Not the final attempt: wait (with exponential backoff) and retry.
            if attempt < self.retry_count - 1:
                self._retry_total += 1
                delay = self.retry_delay * (self.retry_backoff**attempt)
                await asyncio.sleep(delay)

        # Every attempt failed.
        self._send_failures += 1
        if last_error:
            # NOTE(review): the final error is swallowed silently here;
            # logging it would aid observability.
            pass
        return False

    @abstractmethod
    async def _do_send(self, notification: Notification) -> bool:
        """
        Actually deliver one notification (subclass responsibility).

        Args:
            notification: The notification to deliver.

        Returns:
            True on success.
        """
        pass

    async def send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
        """
        Send several notifications with retries.

        Args:
            notifications: Notifications to deliver.

        Returns:
            Delivery summary: {"total": int, "sent": int, "failed": int}
            plus an "error" key when every retry failed.
        """
        # Keep only notifications that pass the level/enabled filter.
        to_send = [n for n in notifications if self.should_send(n)]

        if not to_send:
            return {"total": 0, "sent": 0, "failed": 0}

        self._send_attempts += len(to_send)
        last_error = None
        result = None

        for attempt in range(self.retry_count):
            try:
                result = await self._do_send_batch(to_send)
                if result and result.get("sent", 0) == len(to_send):
                    self._send_successes += result.get("sent", 0)
                    return result
                elif result:
                    # Partial success: record both sides and return as-is.
                    self._send_successes += result.get("sent", 0)
                    self._send_failures += result.get("failed", 0)
                    return result
            except Exception as e:
                last_error = e

            # Not the final attempt: wait (with exponential backoff) and retry.
            if attempt < self.retry_count - 1:
                self._retry_total += 1
                delay = self.retry_delay * (self.retry_backoff**attempt)
                await asyncio.sleep(delay)

        # Every retry failed.
        self._send_failures += len(to_send)
        return {
            "total": len(to_send),
            "sent": 0,
            "failed": len(to_send),
            "error": str(last_error) if last_error else "Unknown error",
        }

    @abstractmethod
    async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
        """
        Actually deliver a batch of notifications (subclass responsibility).

        Args:
            notifications: Notifications to deliver.

        Returns:
            Delivery summary: {"total": int, "sent": int, "failed": int}
        """
        pass

    def should_send(self, notification: Notification) -> bool:
        """Return True when the plugin is enabled and the level passes min_level."""
        if not self.enabled:
            return False

        # Level filter: compare positions in NotificationLevel's declaration order.
        level_values = {level: i for i, level in enumerate(NotificationLevel)}
        if level_values[notification.level] < level_values[self.min_level]:
            return False

        return True

    async def send_error(
        self,
        error: Exception,
        context: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
    ) -> bool:
        """Send an ERROR-level notification describing an exception."""
        notification = Notification(
            title=f"Error: {type(error).__name__}",
            message=str(error),
            level=NotificationLevel.ERROR,
            notification_type="error",
            metadata=context or {},
            recipient=recipient,
            tags=["error", type(error).__name__],
        )
        return await self.send(notification)

    async def send_warning(
        self,
        title: str,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
    ) -> bool:
        """Send a WARNING-level notification."""
        notification = Notification(
            title=title,
            message=message,
            level=NotificationLevel.WARNING,
            notification_type="warning",
            metadata=context or {},
            recipient=recipient,
            tags=["warning"],
        )
        return await self.send(notification)

    async def send_info(
        self,
        title: str,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
    ) -> bool:
        """Send an INFO-level notification."""
        notification = Notification(
            title=title,
            message=message,
            level=NotificationLevel.INFO,
            notification_type="info",
            metadata=context or {},
            recipient=recipient,
            tags=["info"],
        )
        return await self.send(notification)

    async def send_critical(
        self,
        title: str,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        recipient: Optional[str] = None,
    ) -> bool:
        """Send a CRITICAL-level notification."""
        notification = Notification(
            title=title,
            message=message,
            level=NotificationLevel.CRITICAL,
            notification_type="critical",
            metadata=context or {},
            recipient=recipient,
            tags=["critical"],
        )
        return await self.send(notification)

    async def send_usage_alert(
        self,
        user_id: str,
        usage_percent: float,
        limit: int,
        current: int,
        resource_type: str = "tokens",
    ) -> bool:
        """Send a quota-usage alert; severity scales with usage_percent."""
        # >=90% -> CRITICAL, >=75% -> WARNING, otherwise INFO.
        level = NotificationLevel.INFO
        if usage_percent >= 90:
            level = NotificationLevel.CRITICAL
        elif usage_percent >= 75:
            level = NotificationLevel.WARNING

        notification = Notification(
            title=f"Usage Alert: {resource_type.capitalize()}",
            message=f"User {user_id} has used {usage_percent:.1f}% of their {resource_type} quota ({current}/{limit})",
            level=level,
            notification_type="usage_alert",
            metadata={
                "user_id": user_id,
                "usage_percent": usage_percent,
                "limit": limit,
                "current": current,
                "resource_type": resource_type,
            },
            tags=["usage", resource_type],
        )
        return await self.send(notification)

    async def send_provider_status(
        self,
        provider: str,
        status: str,
        error: Optional[str] = None,
        latency: Optional[float] = None,
    ) -> bool:
        """Send a provider health-status notification; severity follows status."""
        # "down" -> CRITICAL, "degraded" -> WARNING, anything else INFO.
        level = NotificationLevel.INFO
        if status == "down":
            level = NotificationLevel.CRITICAL
        elif status == "degraded":
            level = NotificationLevel.WARNING

        message = f"Provider {provider} is {status}"
        if error:
            message += f": {error}"
        if latency:
            message += f" (latency: {latency:.2f}s)"

        notification = Notification(
            title=f"Provider Status: {provider}",
            message=message,
            level=level,
            notification_type="provider_status",
            metadata={"provider": provider, "status": status, "error": error, "latency": latency},
            tags=["provider", status],
        )
        return await self.send(notification)

    async def get_stats(self) -> Dict[str, Any]:
        """
        Return delivery statistics.

        Returns:
            Base retry/delivery statistics merged with any subclass-specific
            statistics from :meth:`_get_extra_stats`.
        """
        base_stats = {
            "plugin_name": self.name,
            "enabled": self.enabled,
            "send_attempts": self._send_attempts,
            "send_successes": self._send_successes,
            "send_failures": self._send_failures,
            "retry_total": self._retry_total,
            "success_rate": (
                self._send_successes / self._send_attempts * 100 if self._send_attempts > 0 else 0
            ),
            "config": {
                "min_level": self.min_level.value,
                "retry_count": self.retry_count,
                "retry_delay": self.retry_delay,
                "retry_backoff": self.retry_backoff,
                "batch_size": self.batch_size,
                "flush_interval": self.flush_interval,
            },
        }

        # Merge in subclass-specific statistics, if any.
        extra_stats = await self._get_extra_stats()
        if extra_stats:
            base_stats.update(extra_stats)

        return base_stats

    async def _get_extra_stats(self) -> Dict[str, Any]:
        """
        Subclass hook for extra statistics (optional override).

        Returns:
            Additional statistics to merge into :meth:`get_stats`.
        """
        return {}
|
||||
374
src/plugins/notification/email.py
Normal file
374
src/plugins/notification/email.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
邮件通知插件
|
||||
通过SMTP发送邮件通知
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import aiosmtplib
|
||||
|
||||
AIOSMTPLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOSMTPLIB_AVAILABLE = False
|
||||
aiosmtplib = None
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from .base import Notification, NotificationLevel, NotificationPlugin
|
||||
|
||||
|
||||
|
||||
class EmailNotificationPlugin(NotificationPlugin):
|
||||
"""
|
||||
邮件通知插件
|
||||
支持HTML和纯文本邮件
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "email", config: Dict[str, Any] = None):
|
||||
super().__init__(name, config)
|
||||
|
||||
# SMTP配置
|
||||
self.smtp_host = config.get("smtp_host") if config else None
|
||||
self.smtp_port = config.get("smtp_port", 587) if config else 587
|
||||
self.smtp_user = config.get("smtp_user") if config else None
|
||||
self.smtp_password = config.get("smtp_password") if config else None
|
||||
self.use_tls = config.get("use_tls", True) if config else True
|
||||
self.use_ssl = config.get("use_ssl", False) if config else False
|
||||
|
||||
# 邮件配置
|
||||
self.from_email = config.get("from_email") if config else None
|
||||
self.from_name = (
|
||||
config.get("from_name", "Aether") if config else "Aether"
|
||||
)
|
||||
self.to_emails = config.get("to_emails", []) if config else []
|
||||
self.cc_emails = config.get("cc_emails", []) if config else []
|
||||
self.bcc_emails = config.get("bcc_emails", []) if config else []
|
||||
|
||||
# 模板配置
|
||||
self.use_html = config.get("use_html", True) if config else True
|
||||
self.subject_prefix = (
|
||||
config.get("subject_prefix", "[Aether]") if config else "[Aether]"
|
||||
)
|
||||
|
||||
# 缓冲配置
|
||||
self._buffer: List[Notification] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task = None
|
||||
|
||||
# 验证配置
|
||||
config_errors = []
|
||||
if not self.smtp_host:
|
||||
config_errors.append("缺少 smtp_host")
|
||||
if not self.from_email:
|
||||
config_errors.append("缺少 from_email")
|
||||
if not self.to_emails:
|
||||
config_errors.append("缺少 to_emails")
|
||||
|
||||
if config_errors:
|
||||
self.enabled = False
|
||||
for error in config_errors:
|
||||
logger.warning(f"Email 插件配置错误: {error},插件已禁用")
|
||||
return
|
||||
|
||||
# 注意: 不在这里启动刷新任务,因为可能还没有运行的事件循环
|
||||
# 需要在应用启动后调用 initialize() 方法来启动任务
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
初始化插件(在事件循环运行后调用)
|
||||
启动后台任务等需要事件循环的操作
|
||||
|
||||
Returns:
|
||||
初始化成功返回 True,失败返回 False
|
||||
"""
|
||||
if not self.enabled:
|
||||
# 配置无效,插件被禁用
|
||||
return False
|
||||
|
||||
if self._flush_task is None:
|
||||
self._start_flush_task()
|
||||
|
||||
return True
|
||||
|
||||
def _start_flush_task(self):
|
||||
"""启动定时刷新任务"""
|
||||
|
||||
async def flush_loop():
|
||||
while self.enabled:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self.flush()
|
||||
|
||||
try:
|
||||
# 获取当前运行的事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
self._flush_task = loop.create_task(flush_loop())
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,任务将在 initialize() 中创建
|
||||
logger.warning("Email 插件刷新任务等待事件循环创建")
|
||||
pass
|
||||
|
||||
def _format_html_email(self, notifications: List[Notification]) -> str:
|
||||
"""格式化HTML邮件"""
|
||||
# 颜色映射
|
||||
color_map = {
|
||||
NotificationLevel.INFO: "#28a745",
|
||||
NotificationLevel.WARNING: "#ffc107",
|
||||
NotificationLevel.ERROR: "#dc3545",
|
||||
NotificationLevel.CRITICAL: "#721c24",
|
||||
}
|
||||
|
||||
# 构建HTML
|
||||
html = """
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; line-height: 1.6; }
|
||||
.notification { margin: 20px 0; padding: 15px; border-left: 5px solid; }
|
||||
.info { border-left-color: #28a745; background-color: #d4edda; }
|
||||
.warning { border-left-color: #ffc107; background-color: #fff3cd; }
|
||||
.error { border-left-color: #dc3545; background-color: #f8d7da; }
|
||||
.critical { border-left-color: #721c24; background-color: #f8d7da; }
|
||||
.title { font-weight: bold; font-size: 1.2em; margin-bottom: 10px; }
|
||||
.metadata { margin-top: 10px; font-size: 0.9em; color: #666; }
|
||||
.timestamp { font-size: 0.8em; color: #999; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>Notifications from Aether</h2>
|
||||
"""
|
||||
|
||||
for notification in notifications:
|
||||
level_class = notification.level.value
|
||||
html += f"""
|
||||
<div class="notification {level_class}">
|
||||
<div class="title">{notification.title}</div>
|
||||
<div class="message">{notification.message}</div>
|
||||
"""
|
||||
|
||||
if notification.metadata:
|
||||
html += '<div class="metadata">'
|
||||
for key, value in notification.metadata.items():
|
||||
html += f"<strong>{key}:</strong> {value}<br>"
|
||||
html += "</div>"
|
||||
|
||||
html += f"""
|
||||
<div class="timestamp">{notification.timestamp.isoformat()}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return html
|
||||
|
||||
def _format_text_email(self, notifications: List[Notification]) -> str:
|
||||
"""格式化纯文本邮件"""
|
||||
lines = ["Notifications from Aether", "=" * 50, ""]
|
||||
|
||||
for notification in notifications:
|
||||
lines.append(f"[{notification.level.value.upper()}] {notification.title}")
|
||||
lines.append("-" * 40)
|
||||
lines.append(notification.message)
|
||||
|
||||
if notification.metadata:
|
||||
lines.append("")
|
||||
for key, value in notification.metadata.items():
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
lines.append(f"\nTime: {notification.timestamp.isoformat()}")
|
||||
lines.append("=" * 50)
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _send_email_async(self, subject: str, body: str, is_html: bool = True) -> bool:
|
||||
"""异步发送邮件"""
|
||||
if AIOSMTPLIB_AVAILABLE:
|
||||
# 使用异步SMTP
|
||||
message = MIMEMultipart("alternative")
|
||||
message["Subject"] = f"{self.subject_prefix} {subject}"
|
||||
message["From"] = f"{self.from_name} <{self.from_email}>"
|
||||
message["To"] = ", ".join(self.to_emails)
|
||||
|
||||
if self.cc_emails:
|
||||
message["Cc"] = ", ".join(self.cc_emails)
|
||||
|
||||
# 添加内容
|
||||
if is_html:
|
||||
message.attach(MIMEText(body, "html"))
|
||||
else:
|
||||
message.attach(MIMEText(body, "plain"))
|
||||
|
||||
try:
|
||||
# 发送邮件
|
||||
if self.use_ssl:
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=self.smtp_host,
|
||||
port=self.smtp_port,
|
||||
use_tls=True,
|
||||
username=self.smtp_user,
|
||||
password=self.smtp_password,
|
||||
)
|
||||
else:
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=self.smtp_host,
|
||||
port=self.smtp_port,
|
||||
start_tls=self.use_tls,
|
||||
username=self.smtp_user,
|
||||
password=self.smtp_password,
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"异步邮件发送失败: {e}")
|
||||
return False
|
||||
else:
|
||||
# 使用同步SMTP(在线程中运行)
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._send_email_sync, subject, body, is_html
|
||||
)
|
||||
|
||||
def _send_email_sync(self, subject: str, body: str, is_html: bool = True) -> bool:
    """Send one email using the blocking smtplib client.

    Args:
        subject: Mail subject (``self.subject_prefix`` is prepended).
        body: Mail body, HTML or plain text depending on ``is_html``.
        is_html: Whether ``body`` is HTML.

    Returns:
        True if the mail was sent successfully, False otherwise.
    """
    message = MIMEMultipart("alternative")
    message["Subject"] = f"{self.subject_prefix} {subject}"
    message["From"] = f"{self.from_name} <{self.from_email}>"
    message["To"] = ", ".join(self.to_emails)
    if self.cc_emails:
        message["Cc"] = ", ".join(self.cc_emails)
    message.attach(MIMEText(body, "html" if is_html else "plain"))

    try:
        # Connect to the SMTP server.
        if self.use_ssl:
            server = smtplib.SMTP_SSL(self.smtp_host, self.smtp_port)
        else:
            server = smtplib.SMTP(self.smtp_host, self.smtp_port)
            if self.use_tls:
                server.starttls()

        # BUGFIX: quit() was previously skipped when login/send raised,
        # leaking the connection. Guarantee it with try/finally.
        try:
            if self.smtp_user and self.smtp_password:
                server.login(self.smtp_user, self.smtp_password)

            # BCC recipients go only into the envelope, never the headers.
            all_recipients = self.to_emails + self.cc_emails + self.bcc_emails
            server.send_message(message, to_addrs=all_recipients)
        finally:
            server.quit()

        return True

    except Exception as e:
        logger.error(f"同步邮件发送失败: {e}")
        return False
|
||||
|
||||
async def _do_send(self, notification: Notification) -> bool:
    """Queue a single notification for delivery.

    CRITICAL notifications flush the buffer immediately; all other levels
    are buffered until the batch size is reached (or the periodic flush
    task fires).
    """
    async with self._lock:
        self._buffer.append(notification)

        # Either condition forces an immediate flush of the buffer.
        immediate = (
            notification.level == NotificationLevel.CRITICAL
            or len(self._buffer) >= self.batch_size
        )
        if immediate:
            return await self._flush_buffer()
        return True
|
||||
|
||||
async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
    """Send a batch of notifications as one email.

    Returns:
        Summary dict with ``total``/``sent``/``failed`` counts.
    """
    if not notifications:
        return {"total": 0, "sent": 0, "failed": 0}

    count = len(notifications)
    subject = f"Batch Notifications ({count} items)"

    # Escalate the subject when any critical notification is present.
    if any(n.level == NotificationLevel.CRITICAL for n in notifications):
        subject = f"[CRITICAL] {subject}"

    # Render the body in the configured format.
    formatter = self._format_html_email if self.use_html else self._format_text_email
    body = formatter(notifications)

    success = await self._send_email_async(subject, body, self.use_html)
    sent = count if success else 0
    return {"total": count, "sent": sent, "failed": count - sent}
|
||||
|
||||
async def _flush_buffer(self) -> bool:
    """Flush buffered notifications (internal; caller must hold ``self._lock``)."""
    if not self._buffer:
        return True

    pending = list(self._buffer)
    self._buffer.clear()

    # Call _do_send_batch directly so statistics are not double-counted.
    outcome = await self._do_send_batch(pending)
    return outcome["failed"] == 0
|
||||
|
||||
async def flush(self) -> bool:
    """Flush any buffered notifications, taking the buffer lock first."""
    async with self._lock:
        flushed_ok = await self._flush_buffer()
    return flushed_ok
|
||||
|
||||
async def _get_extra_stats(self) -> Dict[str, Any]:
    """Return email-channel-specific statistics for monitoring."""
    stats: Dict[str, Any] = {
        "type": "email",
        "smtp_host": self.smtp_host,
        "smtp_port": self.smtp_port,
        "from_email": self.from_email,
        "recipients_count": len(self.to_emails),
        "buffer_size": len(self._buffer),
        "use_html": self.use_html,
        "aiosmtplib_available": AIOSMTPLIB_AVAILABLE,
    }
    return stats
|
||||
|
||||
async def close(self):
    """Shut the plugin down: flush what is buffered, then stop the flush task."""
    await self.flush()

    task = self._flush_task
    if task:
        task.cancel()
|
||||
|
||||
def __del__(self):
    """Best-effort async cleanup when the plugin is garbage-collected.

    ``__del__`` may run with no event loop (e.g. at interpreter shutdown).
    Checking for a running loop *before* creating the ``close()`` coroutine
    avoids both a RuntimeError and a "coroutine was never awaited" warning.
    The previous bare ``except:`` also swallowed SystemExit/KeyboardInterrupt;
    narrowed to ``Exception``.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # no running loop — nothing safe to schedule
    try:
        loop.create_task(self.close())
    except Exception:
        # Raising inside __del__ only produces unraisable-exception noise.
        pass
|
||||
309
src/plugins/notification/webhook.py
Normal file
309
src/plugins/notification/webhook.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Webhook通知插件
|
||||
通过HTTP Webhook发送通知
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
aiohttp = None
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from .base import Notification, NotificationLevel, NotificationPlugin
|
||||
|
||||
|
||||
class WebhookNotificationPlugin(NotificationPlugin):
    """Webhook notification plugin.

    Delivers notifications via HTTP POST and supports several payload
    formats: generic JSON, Slack, Discord and Microsoft Teams.
    Non-critical notifications are buffered and sent in batches.
    """

    def __init__(self, name: str = "webhook", config: Dict[str, Any] = None):
        super().__init__(name, config)

        if not AIOHTTP_AVAILABLE:
            self.enabled = False
            logger.warning("aiohttp not installed, webhook plugin disabled")
            return

        cfg = config or {}
        # Webhook configuration.
        self.webhook_url = cfg.get("webhook_url")
        self.webhook_type = cfg.get("webhook_type", "generic")  # generic, slack, discord, teams
        self.secret = cfg.get("secret")  # HMAC signing key (optional)
        self.timeout = cfg.get("timeout", 30)
        self.headers = cfg.get("headers", {})

        # Buffering state.
        self._buffer: List[Notification] = []
        self._lock = asyncio.Lock()
        self._session: Optional[aiohttp.ClientSession] = None
        self._flush_task = None

        if not self.webhook_url:
            self.enabled = False
            logger.warning("No webhook URL configured")
            return

        # Start the periodic background flush.
        self._start_flush_task()

    def _start_flush_task(self):
        """Schedule the background task that periodically flushes the buffer.

        NOTE(review): requires a running event loop at construction time;
        ``self.flush_interval`` is presumably defined by the base
        NotificationPlugin — confirm against base class.
        """

        async def flush_loop():
            while self.enabled:
                await asyncio.sleep(self.flush_interval)
                await self.flush()

        self._flush_task = asyncio.create_task(flush_loop())

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it lazily."""
        if not self._session:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self._session

    def _generate_signature(self, payload: str) -> str:
        """HMAC-SHA256 hex signature of *payload*; empty string when no secret."""
        if not self.secret:
            return ""
        return hmac.new(
            self.secret.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256
        ).hexdigest()

    def _format_for_slack(self, notification: Notification) -> Dict[str, Any]:
        """Render *notification* as a Slack attachments payload."""
        color_map = {
            NotificationLevel.INFO: "#36a64f",
            NotificationLevel.WARNING: "warning",
            NotificationLevel.ERROR: "danger",
            NotificationLevel.CRITICAL: "#ff0000",
        }
        return {
            "text": notification.title,
            "attachments": [
                {
                    "color": color_map.get(notification.level, "#808080"),
                    "title": notification.title,
                    "text": notification.message,
                    "fields": (
                        [
                            {"title": k, "value": str(v), "short": True}
                            for k, v in notification.metadata.items()
                        ]
                        if notification.metadata
                        else []
                    ),
                    "footer": notification.source,
                    "ts": int(notification.timestamp.timestamp()),
                }
            ],
        }

    def _format_for_discord(self, notification: Notification) -> Dict[str, Any]:
        """Render *notification* as a Discord embeds payload."""
        color_map = {
            NotificationLevel.INFO: 0x00FF00,
            NotificationLevel.WARNING: 0xFFA500,
            NotificationLevel.ERROR: 0xFF0000,
            NotificationLevel.CRITICAL: 0x8B0000,
        }
        embeds = [
            {
                "title": notification.title,
                "description": notification.message,
                "color": color_map.get(notification.level, 0x808080),
                "fields": (
                    [
                        {"name": k, "value": str(v), "inline": True}
                        for k, v in notification.metadata.items()
                    ]
                    if notification.metadata
                    else []
                ),
                "footer": {"text": notification.source},
                "timestamp": notification.timestamp.isoformat(),
            }
        ]
        return {"embeds": embeds}

    def _format_for_teams(self, notification: Notification) -> Dict[str, Any]:
        """Render *notification* as a Microsoft Teams MessageCard."""
        color_map = {
            NotificationLevel.INFO: "00ff00",
            NotificationLevel.WARNING: "ffa500",
            NotificationLevel.ERROR: "ff0000",
            NotificationLevel.CRITICAL: "8b0000",
        }
        facts = (
            [{"name": k, "value": str(v)} for k, v in notification.metadata.items()]
            if notification.metadata
            else []
        )
        return {
            "@type": "MessageCard",
            "@context": "https://schema.org/extensions",
            "themeColor": color_map.get(notification.level, "808080"),
            "title": notification.title,
            "text": notification.message,
            "sections": [{"facts": facts}] if facts else [],
            "summary": notification.title,
        }

    def _format_payload(self, notification: Notification) -> Dict[str, Any]:
        """Select the payload format matching ``self.webhook_type``."""
        if self.webhook_type == "slack":
            return self._format_for_slack(notification)
        elif self.webhook_type == "discord":
            return self._format_for_discord(notification)
        elif self.webhook_type == "teams":
            return self._format_for_teams(notification)
        else:
            # Generic: the notification's own dict representation.
            return notification.to_dict()

    async def _do_send(self, notification: Notification) -> bool:
        """Queue a notification for delivery.

        CRITICAL notifications (and a full buffer) trigger an immediate
        flush; everything else waits for the batch/periodic flush.
        """
        async with self._lock:
            self._buffer.append(notification)

            if notification.level == NotificationLevel.CRITICAL:
                return await self._flush_buffer()

            if len(self._buffer) >= self.batch_size:
                return await self._flush_buffer()

            return True

    async def _do_send_batch(self, notifications: List[Notification]) -> Dict[str, Any]:
        """POST each notification to the webhook; return a delivery summary."""
        if not notifications:
            # BUGFIX: include "errors" so the result shape matches the
            # non-empty case (backward-compatible addition).
            return {"total": 0, "sent": 0, "failed": 0, "errors": []}

        success_count = 0
        failed_count = 0
        errors: List[str] = []

        for notification in notifications:
            try:
                payload_str = json.dumps(self._format_payload(notification))

                headers = dict(self.headers)
                headers["Content-Type"] = "application/json"

                # Sign the body so the receiver can verify authenticity.
                if self.secret:
                    headers["X-Signature"] = self._generate_signature(payload_str)
                    headers["X-Timestamp"] = str(int(time.time()))

                session = await self._get_session()
                async with session.post(
                    self.webhook_url, data=payload_str, headers=headers
                ) as response:
                    if response.status < 300:
                        success_count += 1
                    else:
                        failed_count += 1
                        error_text = await response.text()
                        errors.append(f"HTTP {response.status}: {error_text}")

            except Exception as e:
                failed_count += 1
                errors.append(str(e))

        return {
            "total": len(notifications),
            "sent": success_count,
            "failed": failed_count,
            "errors": errors,
        }

    async def _flush_buffer(self) -> bool:
        """Flush buffered notifications (internal; caller must hold ``self._lock``)."""
        if not self._buffer:
            return True

        pending = self._buffer[:]
        self._buffer.clear()

        # Call _do_send_batch directly so statistics are not double-counted.
        result = await self._do_send_batch(pending)
        return result["failed"] == 0

    async def flush(self) -> bool:
        """Flush buffered notifications, taking the buffer lock first."""
        async with self._lock:
            return await self._flush_buffer()

    async def _get_extra_stats(self) -> Dict[str, Any]:
        """Return webhook-specific statistics (URL query string redacted)."""
        return {
            "type": "webhook",
            "webhook_type": self.webhook_type,
            "webhook_url": (
                self.webhook_url.split("?")[0] if self.webhook_url else None
            ),
            "buffer_size": len(self._buffer),
            "has_secret": bool(self.secret),
        }

    async def _do_shutdown(self):
        """Plugin-framework shutdown hook: delegate to close()."""
        await self.close()

    async def close(self):
        """Flush pending notifications and release network resources."""
        await self.flush()

        if self._flush_task:
            self._flush_task.cancel()

        if self._session:
            await self._session.close()

    def __del__(self):
        """Best-effort async cleanup when garbage-collected.

        BUGFIX: the old bare ``except:`` swallowed SystemExit/KeyboardInterrupt
        and the unconditional create_task produced an un-awaited coroutine
        when no loop was running.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # no running loop — nothing safe to schedule
        try:
            loop.create_task(self.close())
        except Exception:
            pass
|
||||
9
src/plugins/rate_limit/__init__.py
Normal file
9
src/plugins/rate_limit/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
速率限制插件模块
|
||||
"""
|
||||
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
from .sliding_window import SlidingWindowStrategy
|
||||
from .token_bucket import TokenBucketStrategy
|
||||
|
||||
__all__ = ["RateLimitStrategy", "RateLimitResult", "TokenBucketStrategy", "SlidingWindowStrategy"]
|
||||
132
src/plugins/rate_limit/base.py
Normal file
132
src/plugins/rate_limit/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
速率限制策略基类
|
||||
定义速率限制策略的接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..common import BasePlugin, HealthStatus, PluginMetadata
|
||||
|
||||
|
||||
@dataclass
class RateLimitResult:
    """Result of a rate-limit check.

    Attributes:
        allowed: Whether the request may proceed.
        remaining: Quota remaining after this check.
        reset_at: When the current limit window resets, if known.
        retry_after: Seconds the caller should wait before retrying.
        message: Optional human-readable explanation.
        headers: HTTP headers describing the limit state; populated
            automatically from the other fields in ``__post_init__``.
    """

    allowed: bool
    remaining: int
    reset_at: Optional[datetime] = None
    retry_after: Optional[int] = None
    message: Optional[str] = None
    headers: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # Derive the standard rate-limit headers from the scalar fields.
        hdrs = self.headers if self.headers is not None else {}
        if self.remaining is not None:
            hdrs["X-RateLimit-Remaining"] = str(self.remaining)
        if self.reset_at:
            hdrs["X-RateLimit-Reset"] = str(int(self.reset_at.timestamp()))
        if self.retry_after:
            hdrs["Retry-After"] = str(self.retry_after)
        self.headers = hdrs
|
||||
|
||||
|
||||
class RateLimitStrategy(BasePlugin):
    """Abstract base class for rate-limiting strategies.

    Concrete strategies (token bucket, sliding window, ...) must implement
    ``check_limit``, ``consume``, ``reset`` and ``get_stats``.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: List[str] = None,
        provides: List[str] = None,
        config: Dict[str, Any] = None,
    ):
        """Initialize the strategy.

        Args:
            name: Strategy name.
            priority: Priority (larger numbers take precedence).
            version: Plugin version.
            author: Plugin author.
            description: Plugin description.
            api_version: Plugin API version.
            dependencies: Names of plugins this one depends on.
            provides: Services this plugin provides.
            config: Configuration dictionary.
        """
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )

    @abstractmethod
    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """Check the rate limit for *key* (e.g. a user or API-key id).

        Returns:
            A RateLimitResult describing the decision.
        """

    @abstractmethod
    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """Consume *amount* units of quota for *key*.

        Returns:
            True when the quota was successfully consumed.
        """

    @abstractmethod
    async def reset(self, key: str):
        """Reset all limit state recorded for *key*."""

    @abstractmethod
    async def get_stats(self, key: str) -> Dict[str, Any]:
        """Return a statistics dict describing the limit state for *key*."""
|
||||
363
src/plugins/rate_limit/sliding_window.py
Normal file
363
src/plugins/rate_limit/sliding_window.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
滑动窗口算法速率限制策略
|
||||
精确的速率限制,不允许突发
|
||||
|
||||
WARNING: 多进程环境注意事项
|
||||
=============================
|
||||
此插件的窗口状态存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
|
||||
每个 worker 进程有独立的限流状态,可能导致:
|
||||
- 实际允许的请求数 = 配置限制 * worker数量
|
||||
- 限流效果大打折扣
|
||||
|
||||
解决方案:
|
||||
1. 单 worker 模式:适用于低流量场景
|
||||
2. Redis 共享状态:使用 Redis 实现分布式滑动窗口
|
||||
3. 使用 token_bucket.py:令牌桶策略可以更容易迁移到 Redis
|
||||
|
||||
目前项目已有 Redis 依赖(src/clients/redis_client.py),
|
||||
建议在生产环境使用 Redis 实现分布式限流。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Deque, Dict
|
||||
|
||||
from src.core.logger import logger
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
|
||||
|
||||
|
||||
class SlidingWindow:
    """Fixed-duration sliding window of request timestamps.

    Holds the wall-clock timestamp of every request still inside the
    window; timestamps older than ``window_size`` seconds are dropped
    lazily on each access.
    """

    def __init__(self, window_size: int, max_requests: int):
        """
        Args:
            window_size: Window duration in seconds.
            max_requests: Maximum requests allowed inside one window.
        """
        self.window_size = window_size
        self.max_requests = max_requests
        self.requests: Deque[float] = deque()
        self.last_access_time: float = time.time()

    def _cleanup(self):
        """Drop timestamps that have slid out of the window."""
        now = time.time()
        self.last_access_time = now  # record activity for expiry-based GC
        threshold = now - self.window_size
        while self.requests and self.requests[0] < threshold:
            self.requests.popleft()

    def can_accept(self, amount: int = 1) -> bool:
        """Return True if *amount* more requests still fit in the window."""
        self._cleanup()
        return len(self.requests) + amount <= self.max_requests

    def add_request(self, amount: int = 1) -> bool:
        """Record *amount* requests; returns False when they do not fit."""
        if not self.can_accept(amount):
            return False
        stamp = time.time()
        self.requests.extend(stamp for _ in range(amount))
        return True

    def get_remaining(self) -> int:
        """Quota still available in the current window."""
        self._cleanup()
        return max(0, self.max_requests - len(self.requests))

    def get_reset_time(self) -> datetime:
        """When the oldest recorded request will expire from the window."""
        self._cleanup()
        if not self.requests:
            return datetime.now()
        # The oldest entry leaves the window window_size seconds after arrival.
        return datetime.fromtimestamp(self.requests[0] + self.window_size)
|
||||
|
||||
|
||||
class SlidingWindowStrategy(RateLimitStrategy):
    """Sliding-window rate limiting strategy.

    Characteristics:
        - precise rate limiting; bursts are not allowed
        - suited to scenarios requiring strict rate control
        - idle windows are cleaned up automatically to avoid memory leaks

    Note: all window state lives in process memory, so with multiple
    worker processes each process enforces its own independent limit
    (see the module docstring).
    """

    # Maximum number of windows kept in memory before LRU eviction.
    DEFAULT_MAX_WINDOWS = 10000
    # Windows not touched for this many seconds are garbage-collected.
    DEFAULT_WINDOW_EXPIRY = 3600  # 1 hour

    def __init__(self):
        super().__init__("sliding_window")
        self.windows: Dict[str, SlidingWindow] = {}
        self._lock = asyncio.Lock()

        # Defaults used when no key-specific configuration matches.
        self.default_window_size = 60  # seconds
        self.default_max_requests = 100

        # Memory-management settings.
        self.max_windows = self.DEFAULT_MAX_WINDOWS
        self.window_expiry = self.DEFAULT_WINDOW_EXPIRY
        self._last_cleanup_time: float = time.time()
        self._cleanup_interval = 300  # re-check every 5 minutes

    def _cleanup_expired_windows(self) -> int:
        """Delete windows idle longer than ``window_expiry``.

        Returns:
            Number of windows removed.
        """
        now = time.time()
        expired_keys = [
            key
            for key, window in self.windows.items()
            if now - window.last_access_time > self.window_expiry
        ]
        for key in expired_keys:
            del self.windows[key]

        if expired_keys:
            logger.info(f"清理了 {len(expired_keys)} 个过期的滑动窗口")
        return len(expired_keys)

    def _evict_lru_windows(self, count: int) -> int:
        """Evict the *count* least-recently-used windows.

        Args:
            count: Number of windows to evict.

        Returns:
            Number of windows actually evicted.
        """
        if not self.windows or count <= 0:
            return 0

        # Oldest last_access_time first.
        sorted_keys = sorted(self.windows, key=lambda k: self.windows[k].last_access_time)

        evicted = 0
        for key in sorted_keys[:count]:
            del self.windows[key]
            evicted += 1

        if evicted:
            logger.warning(f"LRU 淘汰了 {evicted} 个滑动窗口(达到容量上限)")
        return evicted

    async def _maybe_cleanup(self):
        """Run periodic expiry cleanup plus capacity-based LRU eviction."""
        now = time.time()

        if now - self._last_cleanup_time > self._cleanup_interval:
            self._cleanup_expired_windows()
            self._last_cleanup_time = now

        if len(self.windows) >= self.max_windows:
            # Evict 10% of capacity to make headroom.
            self._evict_lru_windows(max(1, self.max_windows // 10))

    def _get_window(self, key: str) -> SlidingWindow:
        """Return the window for *key*, creating it with per-prefix defaults.

        Keys prefixed ``api_key:`` and ``user:`` read their limits from the
        plugin config; everything else uses the global defaults.
        """
        if key not in self.windows:
            if key.startswith("api_key:"):
                window_size = self.config.get("api_key_window_size", self.default_window_size)
                max_requests = self.config.get("api_key_max_requests", self.default_max_requests)
            elif key.startswith("user:"):
                window_size = self.config.get("user_window_size", self.default_window_size)
                max_requests = self.config.get("user_max_requests", self.default_max_requests * 2)
            else:
                window_size = self.default_window_size
                max_requests = self.default_max_requests

            self.windows[key] = SlidingWindow(window_size, max_requests)

        return self.windows[key]

    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """Check whether *key* may make more requests.

        Args:
            key: Limit key (e.g. user ID or API-key ID).
            **kwargs: ``amount`` — number of requests to test for (default 1).

        Returns:
            RateLimitResult describing the decision.
        """
        async with self._lock:
            # Opportunistically garbage-collect stale windows.
            await self._maybe_cleanup()

            window = self._get_window(key)
            amount = kwargs.get("amount", 1)

            allowed = window.can_accept(amount)
            remaining = window.get_remaining()
            reset_at = window.get_reset_time()

            retry_after = None
            if not allowed:
                # Wait until the oldest request expires from the window.
                retry_after = int((reset_at - datetime.now()).total_seconds()) + 1

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                message=(
                    None
                    if allowed
                    else f"Rate limit exceeded. Please retry after {retry_after} seconds."
                ),
            )

    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """Record *amount* requests against *key*.

        Returns:
            False when the window is already at its limit.
        """
        async with self._lock:
            window = self._get_window(key)
            success = window.add_request(amount)

            # f-prefix removed from constant log strings (no placeholders).
            if success:
                logger.debug("滑动窗口请求记录成功")
            else:
                logger.warning("滑动窗口请求被拒绝:超出速率限制")

            return success

    async def reset(self, key: str):
        """Clear all recorded requests for *key*."""
        async with self._lock:
            if key in self.windows:
                self.windows[key].requests.clear()

            logger.info("滑动窗口已重置")

    async def get_stats(self, key: str) -> Dict[str, Any]:
        """Return current window statistics for *key*."""
        async with self._lock:
            window = self._get_window(key)
            window._cleanup()  # drop expired entries before reporting

            return {
                "strategy": "sliding_window",
                "key": key,
                "window_size": window.window_size,
                "max_requests": window.max_requests,
                "current_requests": len(window.requests),
                "remaining": window.get_remaining(),
                "reset_at": window.get_reset_time().isoformat(),
            }

    def configure(self, config: Dict[str, Any]):
        """Apply configuration.

        Supported keys: ``default_window_size``, ``default_max_requests``,
        ``max_windows``, ``window_expiry``, ``cleanup_interval`` (plus the
        per-prefix keys read lazily in ``_get_window``).
        """
        super().configure(config)
        self.default_window_size = config.get("default_window_size", self.default_window_size)
        self.default_max_requests = config.get("default_max_requests", self.default_max_requests)
        self.max_windows = config.get("max_windows", self.max_windows)
        self.window_expiry = config.get("window_expiry", self.window_expiry)
        self._cleanup_interval = config.get("cleanup_interval", self._cleanup_interval)

    def get_memory_stats(self) -> Dict[str, Any]:
        """Return memory-usage statistics for the window cache."""
        total = len(self.windows)
        return {
            "total_windows": total,
            "max_windows": self.max_windows,
            "window_expiry": self.window_expiry,
            "cleanup_interval": self._cleanup_interval,
            "last_cleanup_time": self._last_cleanup_time,
            "usage_percent": (total / self.max_windows * 100) if self.max_windows > 0 else 0,
        }
|
||||
431
src/plugins/rate_limit/token_bucket.py
Normal file
431
src/plugins/rate_limit/token_bucket.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""令牌桶速率限制策略,支持 Redis 分布式后端"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from ...clients.redis_client import get_redis_client_sync
|
||||
from src.core.logger import logger
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
|
||||
|
||||
|
||||
class TokenBucket:
    """Classic token bucket: a capacity-bounded pool refilled at a fixed rate.

    Tokens are replenished lazily, based on the wall-clock time elapsed
    since the previous refill.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.time()

    def _refill(self):
        """Top up tokens proportionally to the time since the last refill."""
        now = time.time()
        gained = (now - self.last_refill) * self.refill_rate
        if gained > 0:
            self.tokens = min(self.capacity, self.tokens + gained)
            self.last_refill = now

    def consume(self, amount: int = 1) -> bool:
        """Take *amount* tokens; returns False when not enough are available."""
        self._refill()
        if self.tokens < amount:
            return False
        self.tokens -= amount
        return True

    def get_remaining(self) -> int:
        """Whole tokens currently available."""
        self._refill()
        return int(self.tokens)

    def get_reset_time(self) -> datetime:
        """Estimated time at which the bucket is full again."""
        if self.tokens >= self.capacity:
            return datetime.now()
        deficit = self.capacity - self.tokens
        return datetime.now() + timedelta(seconds=deficit / self.refill_rate)
|
||||
|
||||
|
||||
class TokenBucketStrategy(RateLimitStrategy):
|
||||
"""
|
||||
令牌桶算法速率限制策略
|
||||
|
||||
特点:
|
||||
- 允许突发流量
|
||||
- 平均速率受限
|
||||
- 适合处理不均匀的流量模式
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("token_bucket")
|
||||
self.buckets: Dict[str, TokenBucket] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 默认配置
|
||||
self.default_capacity = 100 # 默认桶容量
|
||||
self.default_refill_rate = 10 # 默认每秒补充10个令牌
|
||||
|
||||
# 可选的 Redis 后端
|
||||
self._redis_backend: Optional[RedisTokenBucketBackend] = None
|
||||
self._redis_checked = False
|
||||
self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()
|
||||
|
||||
def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
|
||||
"""
|
||||
获取或创建令牌桶
|
||||
|
||||
Args:
|
||||
key: 限制键
|
||||
rate_limit: 每分钟请求限制(来自数据库配置),如果提供则使用此值
|
||||
|
||||
Returns:
|
||||
令牌桶实例
|
||||
"""
|
||||
if key not in self.buckets:
|
||||
# 如果提供了rate_limit参数(来自数据库),优先使用
|
||||
if rate_limit is not None:
|
||||
# rate_limit 是每分钟请求数,转换为令牌桶参数
|
||||
capacity = rate_limit # 桶容量等于每分钟限制
|
||||
refill_rate = rate_limit / 60.0 # 每秒补充的令牌数
|
||||
# 否则根据key的不同前缀使用不同的配置
|
||||
elif key.startswith("api_key:"):
|
||||
capacity = self.config.get("api_key_capacity", self.default_capacity)
|
||||
refill_rate = self.config.get("api_key_refill_rate", self.default_refill_rate)
|
||||
elif key.startswith("user:"):
|
||||
capacity = self.config.get("user_capacity", self.default_capacity * 2)
|
||||
refill_rate = self.config.get("user_refill_rate", self.default_refill_rate * 2)
|
||||
else:
|
||||
capacity = self.default_capacity
|
||||
refill_rate = self.default_refill_rate
|
||||
|
||||
self.buckets[key] = TokenBucket(capacity, refill_rate)
|
||||
|
||||
return self.buckets[key]
|
||||
|
||||
def _want_redis_backend(self) -> bool:
|
||||
return self._backend_mode in {"auto", "redis"}
|
||||
|
||||
async def _ensure_backend(self):
|
||||
if self._redis_checked:
|
||||
return
|
||||
self._redis_checked = True
|
||||
if not self._want_redis_backend():
|
||||
return
|
||||
redis_client = get_redis_client_sync()
|
||||
if redis_client:
|
||||
self._redis_backend = RedisTokenBucketBackend(redis_client)
|
||||
logger.info("速率限制改用 Redis 令牌桶后端")
|
||||
elif self._backend_mode == "redis":
|
||||
logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")
|
||||
|
||||
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
|
||||
"""
|
||||
检查速率限制
|
||||
|
||||
Args:
|
||||
key: 限制键
|
||||
**kwargs: 额外参数,包括 rate_limit (从数据库配置)
|
||||
|
||||
Returns:
|
||||
速率限制检查结果
|
||||
"""
|
||||
await self._ensure_backend()
|
||||
|
||||
rate_limit = kwargs.get("rate_limit")
|
||||
amount = kwargs.get("amount", 1)
|
||||
|
||||
if self._redis_backend:
|
||||
return await self._redis_backend.peek(
|
||||
key=key,
|
||||
capacity=self._resolve_capacity(key, rate_limit),
|
||||
refill_rate=self._resolve_refill_rate(key, rate_limit),
|
||||
amount=amount,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
bucket = self._get_bucket(key, rate_limit)
|
||||
remaining = bucket.get_remaining()
|
||||
reset_at = bucket.get_reset_time()
|
||||
|
||||
allowed = remaining >= amount
|
||||
|
||||
retry_after = None
|
||||
if not allowed:
|
||||
tokens_needed = amount - remaining
|
||||
retry_after = int(tokens_needed / bucket.refill_rate) + 1
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=allowed,
|
||||
remaining=remaining,
|
||||
reset_at=reset_at,
|
||||
retry_after=retry_after,
|
||||
message=(
|
||||
None
|
||||
if allowed
|
||||
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
|
||||
),
|
||||
)
|
||||
|
||||
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
    """Consume *amount* tokens for *key*.

    Args:
        key: Rate-limit key.
        amount: Number of tokens to consume (default 1).
        **kwargs: Optional ``rate_limit`` — per-minute limit from the
            database configuration, used to size the bucket.

    Returns:
        True when the tokens were consumed, False when over the limit.
    """
    await self._ensure_backend()

    rate_limit = kwargs.get("rate_limit")

    if self._redis_backend:
        success, _remaining = await self._redis_backend.consume(
            key=key,
            capacity=self._resolve_capacity(key, rate_limit),
            refill_rate=self._resolve_refill_rate(key, rate_limit),
            amount=amount,
        )
        if success:
            logger.debug("Redis 令牌消费成功")
        else:
            logger.warning("Redis 令牌消费失败")
        return success

    async with self._lock:
        # Bug fix: pass rate_limit through so a bucket first created here is
        # sized consistently with check_limit() and the Redis path (previously
        # the configured limit was silently ignored on this path).
        bucket = self._get_bucket(key, rate_limit)
        success = bucket.consume(amount)

        if success:
            logger.debug("令牌消费成功")
        else:
            logger.warning("令牌消费失败:超出速率限制")

        return success
|
||||
|
||||
async def reset(self, key: str):
    """Restore the bucket for *key* to full capacity.

    Args:
        key: The rate-limit key to reset.
    """
    await self._ensure_backend()

    if self._redis_backend:
        await self._redis_backend.reset(key)
        return

    async with self._lock:
        bucket = self.buckets.get(key)
        if bucket is not None:
            bucket.tokens = bucket.capacity
            bucket.last_refill = time.time()

    logger.info(f"令牌桶已重置")
|
||||
|
||||
async def get_stats(self, key: str) -> Dict[str, Any]:
    """Return diagnostic information about the bucket for *key*.

    Args:
        key: The rate-limit key.

    Returns:
        A dict with strategy name, capacity, remaining tokens, refill
        rate and the time the bucket will be full again.
    """
    await self._ensure_backend()

    if self._redis_backend:
        return await self._redis_backend.get_stats(
            key,
            capacity=self._resolve_capacity(key),
            refill_rate=self._resolve_refill_rate(key),
        )

    async with self._lock:
        bucket = self._get_bucket(key)
        stats = {
            "strategy": "token_bucket",
            "key": key,
            "capacity": bucket.capacity,
            "remaining": bucket.get_remaining(),
            "refill_rate": bucket.refill_rate,
            "reset_at": bucket.get_reset_time().isoformat(),
        }
    return stats
|
||||
|
||||
def configure(self, config: Dict[str, Any]):
    """Apply runtime configuration.

    Recognised keys (besides those handled by the base class):
    - default_capacity: fallback bucket capacity
    - default_refill_rate: fallback tokens-per-second refill rate
    - api_key_capacity / api_key_refill_rate: API-key bucket overrides
    - user_capacity / user_refill_rate: per-user bucket overrides
    """
    super().configure(config)
    # Only override the defaults when the caller actually supplied them.
    if "default_capacity" in config:
        self.default_capacity = config["default_capacity"]
    if "default_refill_rate" in config:
        self.default_refill_rate = config["default_refill_rate"]
|
||||
|
||||
def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
|
||||
if rate_limit is not None:
|
||||
return rate_limit
|
||||
if key.startswith("api_key:"):
|
||||
return self.config.get("api_key_capacity", self.default_capacity)
|
||||
if key.startswith("user:"):
|
||||
return self.config.get("user_capacity", self.default_capacity * 2)
|
||||
return self.default_capacity
|
||||
|
||||
def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
|
||||
if rate_limit is not None:
|
||||
return rate_limit / 60.0
|
||||
if key.startswith("api_key:"):
|
||||
return self.config.get("api_key_refill_rate", self.default_refill_rate)
|
||||
if key.startswith("user:"):
|
||||
return self.config.get("user_refill_rate", self.default_refill_rate * 2)
|
||||
return self.default_refill_rate
|
||||
|
||||
|
||||
class RedisTokenBucketBackend:
    """Token-bucket state stored in Redis so multiple instances share limits."""

    # Atomic refill-and-consume script: refills the bucket from the elapsed
    # time since the stored timestamp, then tries to spend `amount` tokens.
    # Returns {allowed(0/1), tokens_left, retry_after_seconds}. The key
    # expires after roughly the time a full refill takes, so idle buckets
    # are garbage-collected by Redis itself.
    _SCRIPT = """
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local refill_rate = tonumber(ARGV[3])
    local amount = tonumber(ARGV[4])

    local data = redis.call('HMGET', key, 'tokens', 'timestamp')
    local tokens = tonumber(data[1])
    local last_refill = tonumber(data[2])

    if tokens == nil then
        tokens = capacity
        last_refill = now
    end

    local delta = math.max(0, now - last_refill)
    local refill = delta * refill_rate
    tokens = math.min(capacity, tokens + refill)

    local allowed = 0
    local retry_after = 0
    if tokens >= amount then
        tokens = tokens - amount
        allowed = 1
    else
        retry_after = math.ceil((amount - tokens) / refill_rate)
    end

    redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
    local ttl = math.max(1, math.ceil(capacity / refill_rate))
    redis.call('EXPIRE', key, ttl)
    return {allowed, tokens, retry_after}
    """

    def __init__(self, redis_client):
        # redis_client: async Redis client — hmget/delete are awaited below,
        # and register_script must return a callable awaitable script object.
        self.redis = redis_client
        self._consume_script = self.redis.register_script(self._SCRIPT)

    def _redis_key(self, key: str) -> str:
        """Namespace the logical rate-limit key inside Redis."""
        return f"rate_limit:bucket:{key}"

    async def peek(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> RateLimitResult:
        """Non-consuming check: estimate current tokens from the stored state.

        Reads the stored level/timestamp and replays the refill locally;
        this is advisory only — the authoritative decision happens in the
        atomic Lua script inside consume().
        """
        bucket_key = self._redis_key(key)
        data = await self.redis.hmget(bucket_key, "tokens", "timestamp")
        tokens = data[0]
        last_refill = data[1]

        if tokens is None or last_refill is None:
            # No state yet: a fresh bucket is full.
            remaining = capacity
            reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
        else:
            # Replay the refill locally to estimate the current level.
            tokens_value = float(tokens)
            last_refill_value = float(last_refill)
            delta = max(0.0, time.time() - last_refill_value)
            tokens_value = min(capacity, tokens_value + delta * refill_rate)
            remaining = int(tokens_value)
            reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
            reset_at = datetime.now() + timedelta(seconds=reset_after)

        allowed = remaining >= amount
        retry_after = None
        if not allowed:
            needed = max(0, amount - remaining)
            retry_after = int(needed / refill_rate) + 1

        return RateLimitResult(
            allowed=allowed,
            remaining=int(remaining),
            reset_at=reset_at,
            retry_after=retry_after,
            message=(
                None
                if allowed
                else f"Rate limit exceeded. Please retry after {retry_after} seconds."
            ),
        )

    async def consume(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> Tuple[bool, int]:
        """Atomically consume *amount* tokens via the Lua script.

        Returns:
            (allowed, remaining_tokens) as computed server-side.
        """
        result = await self._consume_script(
            keys=[self._redis_key(key)],
            args=[time.time(), capacity, refill_rate, amount],
        )
        allowed = bool(result[0])
        # The script may return a float-formatted string; normalise to int.
        remaining = int(float(result[1]))
        return allowed, remaining

    async def reset(self, key: str):
        """Drop the stored bucket; the next access sees a full bucket."""
        await self.redis.delete(self._redis_key(key))

    async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
        """Return the raw stored bucket state plus the effective configuration."""
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        tokens = data[0]
        timestamp = data[1]
        return {
            "strategy": "token_bucket",
            "key": key,
            "capacity": capacity,
            "remaining": float(tokens) if tokens else capacity,
            "refill_rate": refill_rate,
            "last_refill": float(timestamp) if timestamp else time.time(),
            "backend": "redis",
        }
|
||||
9
src/plugins/token/__init__.py
Normal file
9
src/plugins/token/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Token计数插件
|
||||
"""
|
||||
|
||||
from .base import TokenCounterPlugin, TokenUsage
|
||||
from .claude_counter import ClaudeTokenCounterPlugin
|
||||
from .tiktoken_counter import TiktokenCounterPlugin
|
||||
|
||||
__all__ = ["TokenCounterPlugin", "TokenUsage", "TiktokenCounterPlugin", "ClaudeTokenCounterPlugin"]
|
||||
170
src/plugins/token/base.py
Normal file
170
src/plugins/token/base.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Token计数插件基类
|
||||
定义Token计数的接口
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from src.plugins.common import BasePlugin
|
||||
|
||||
|
||||
@dataclass
class TokenUsage:
    """Token accounting for a single request/response exchange."""

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0  # Claude prompt-cache reads
    cache_write_tokens: int = 0  # Claude prompt-cache writes
    reasoning_tokens: int = 0  # OpenAI o1 reasoning tokens

    _FIELDS = (
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "cache_read_tokens",
        "cache_write_tokens",
        "reasoning_tokens",
    )

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new TokenUsage with every counter summed field-wise."""
        summed = {name: getattr(self, name) + getattr(other, name) for name in self._FIELDS}
        return TokenUsage(**summed)

    def to_dict(self) -> Dict[str, int]:
        """Serialise the counters to a plain dict."""
        return {name: getattr(self, name) for name in self._FIELDS}
|
||||
|
||||
|
||||
class TokenCounterPlugin(BasePlugin):
    """
    Base class for token-counting plugins.

    Concrete subclasses decide which models they support and implement the
    actual counting; this base provides request/response helpers and a
    shared pricing-based cost estimator.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        # Initialise plugin metadata via the base class.
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )

        # Models this counter advertises (informational; supports_model()
        # is the authoritative check).
        self.supported_models = self.config.get("supported_models", [])
        # Model assumed when the caller does not specify one.
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return True if this counter can count tokens for *model*."""
        pass

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Return the token count of *text* for *model*."""
        pass

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Return the token count of a chat *messages* list for *model*."""
        pass

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the tokens of a chat-completion style request body."""
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a provider response body.

        Recognises both the OpenAI ("prompt_tokens"/"completion_tokens")
        and the Claude ("input_tokens"/"output_tokens") usage formats;
        returns an empty TokenUsage when neither is present.
        """
        usage = response.get("usage", {})

        # OpenAI format
        if "prompt_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("prompt_tokens", 0),
                output_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                reasoning_tokens=usage.get("completion_tokens_details", {}).get(
                    "reasoning_tokens", 0
                ),
            )

        # Claude format
        elif "input_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                cache_read_tokens=usage.get("cache_read_input_tokens", 0),
                cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
            )

        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, float]:
        """Estimate the USD cost of *usage* for *model*.

        Prices come from config["pricing"] (per 1M tokens). An exact model
        match is tried first, then a prefix match; when no entry applies a
        dict with an "error" key is returned instead of cost figures.
        """
        # Pricing table (USD per 1M tokens).
        pricing = self.config.get("pricing", {})

        # Exact model match first.
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Fall back to prefix matching.
            for model_prefix, price_info in pricing.items():
                if model.startswith(model_prefix):
                    model_pricing = price_info
                    break

        if not model_pricing:
            return {"error": "No pricing information available"}

        # Per-category costs.
        input_cost = (usage.input_tokens / 1_000_000) * model_pricing.get("input", 0)
        output_cost = (usage.output_tokens / 1_000_000) * model_pricing.get("output", 0)

        # Prompt-cache costs (Claude-specific).
        cache_read_cost = (usage.cache_read_tokens / 1_000_000) * model_pricing.get("cache_read", 0)
        cache_write_cost = (usage.cache_write_tokens / 1_000_000) * model_pricing.get(
            "cache_write", 0
        )

        # Reasoning cost (OpenAI o1-specific).
        reasoning_cost = (usage.reasoning_tokens / 1_000_000) * model_pricing.get("reasoning", 0)

        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost

        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return metadata (limits, capabilities, pricing) for *model*."""
        pass

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin status information."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }
|
||||
273
src/plugins/token/claude_counter.py
Normal file
273
src/plugins/token/claude_counter.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Claude Token计数插件
|
||||
专门为Claude模型设计的Token计数器
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import TokenCounterPlugin, TokenUsage
|
||||
|
||||
|
||||
class ClaudeTokenCounterPlugin(TokenCounterPlugin):
    """
    Claude-specific token counter.

    Anthropic does not ship a public tokenizer, so this plugin uses a
    simplified character/word-based estimation; counts are approximate.
    """

    # Per-model context limits and the average characters-per-token ratio
    # the estimator uses.
    CLAUDE_MODELS = {
        "claude-3-5-sonnet-20241022": {
            "max_tokens": 200000,
            "max_output": 8192,
            "chars_per_token": 3.5,  # average characters per token
        },
        "claude-3-5-haiku-20241022": {
            "max_tokens": 200000,
            "max_output": 8192,
            "chars_per_token": 3.5,
        },
        "claude-3-opus-20240229": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        "claude-3-sonnet-20240229": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        "claude-3-haiku-20240307": {
            "max_tokens": 200000,
            "max_output": 4096,
            "chars_per_token": 3.5,
        },
        # Legacy models
        "claude-2.1": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
        "claude-2.0": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
        "claude-instant-1.2": {
            "max_tokens": 100000,
            "max_output": 4096,
            "chars_per_token": 4,
        },
    }

    def __init__(self, name: str = "claude", config: Optional[Dict[str, Any]] = None):
        super().__init__(name, config)

        # Pricing table (USD per 1M tokens).
        default_pricing = {
            "claude-3-5-sonnet": {
                "input": 3,
                "output": 15,
                "cache_write": 3.75,  # prompt-cache write
                "cache_read": 0.30,  # prompt-cache read
            },
            "claude-3-5-haiku": {
                "input": 0.8,
                "output": 4,
                "cache_write": 1,
                "cache_read": 0.08,
            },
            "claude-3-opus": {
                "input": 15,
                "output": 75,
                "cache_write": 18.75,
                "cache_read": 1.50,
            },
            "claude-3-sonnet": {
                "input": 3,
                "output": 15,
                "cache_write": 3.75,
                "cache_read": 0.30,
            },
            "claude-3-haiku": {
                "input": 0.25,
                "output": 1.25,
                "cache_write": 0.30,
                "cache_read": 0.03,
            },
            "claude-2.1": {
                "input": 8,
                "output": 24,
            },
            "claude-2.0": {
                "input": 8,
                "output": 24,
            },
            "claude-instant": {
                "input": 0.8,
                "output": 2.4,
            },
        }
        self.config["pricing"] = (
            config.get("pricing", default_pricing) if config else default_pricing
        )

    def supports_model(self, model: str) -> bool:
        """Return True for any model whose name contains "claude"."""
        # All Claude models are supported.
        return "claude" in model.lower()

    def _estimate_tokens_from_text(self, text: str, model: str) -> int:
        """Estimate the token count of *text* using character heuristics."""
        # Look up the estimation parameters for this model family.
        model_info = None
        for model_name, info in self.CLAUDE_MODELS.items():
            if model.startswith(model_name.split("-20")[0]):  # match the base name
                model_info = info
                break

        if not model_info:
            # Fallback parameters for unknown Claude models.
            model_info = {"chars_per_token": 3.5}

        # Base ratio for the estimate.
        chars_per_token = model_info["chars_per_token"]

        # Account for language differences:
        # detect Chinese/Japanese/Korean characters.
        cjk_pattern = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")
        cjk_count = len(cjk_pattern.findall(text))

        if cjk_count > len(text) * 0.3:  # more than 30% CJK characters
            # CJK characters are typically 1-2 tokens each.
            return int(len(text) / 1.5)
        else:
            # English and other languages: blend a word-based and a
            # character-based estimate.
            word_count = len(text.split())
            # ~1.3 tokens per word on average.
            token_by_words = int(word_count * 1.3)
            # chars_per_token characters per token on average.
            token_by_chars = int(len(text) / chars_per_token)
            # Average of the two estimates.
            return (token_by_words + token_by_chars) // 2

    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Estimate the token count of a plain text string."""
        if not self.enabled:
            return 0

        model = model or self.default_model or "claude-3-5-sonnet-20241022"
        return self._estimate_tokens_from_text(text, model)

    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Estimate the token count of a chat message list."""
        if not self.enabled:
            return 0

        model = model or self.default_model or "claude-3-5-sonnet-20241022"
        total_tokens = 0

        for message in messages:
            # Role overhead (~3 tokens).
            total_tokens += 3

            # Content tokens.
            content = message.get("content")
            if content:
                if isinstance(content, str):
                    total_tokens += self._estimate_tokens_from_text(content, model)
                elif isinstance(content, list):
                    # Multimodal content blocks.
                    for item in content:
                        if item.get("type") == "text":
                            text = item.get("text", "")
                            total_tokens += self._estimate_tokens_from_text(text, model)
                        elif item.get("type") == "image":
                            # Claude image handling:
                            # base: ~1,600 tokens,
                            # each 256x256 tile: ~280 tokens.
                            # Simplified here to a flat average.
                            total_tokens += 2000  # rough average estimate
                        elif item.get("type") == "tool_use":
                            # Tool invocation: name + JSON-encoded input.
                            tool_name = item.get("name", "")
                            tool_input = item.get("input", {})
                            total_tokens += self._estimate_tokens_from_text(tool_name, model)
                            total_tokens += self._estimate_tokens_from_text(
                                json.dumps(tool_input), model
                            )
                        elif item.get("type") == "tool_result":
                            # Tool result (string content only is counted).
                            tool_content = item.get("content", "")
                            if isinstance(tool_content, str):
                                total_tokens += self._estimate_tokens_from_text(tool_content, model)

        # Extra overhead when the first message is a system prompt.
        if messages and messages[0].get("role") == "system":
            # System prompts typically carry additional overhead.
            total_tokens += 10

        return total_tokens

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Estimate tokens for a full request, including the system prompt."""
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        total = await self.count_messages(messages, model)

        # Top-level system prompt (Claude passes it outside `messages`).
        system = request.get("system")
        if system:
            # NOTE(review): assumes `system` is a string; the Claude API also
            # accepts a list of content blocks — confirm upstream callers.
            total += self._estimate_tokens_from_text(system, model)
            total += 5  # extra overhead for the system prompt

        return total

    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return limits, capabilities and pricing for *model*."""
        info = {"model": model, "supported": self.supports_model(model)}

        if self.supports_model(model):
            # Find the matching model entry by base name.
            model_info = None
            for model_name, m_info in self.CLAUDE_MODELS.items():
                if model.startswith(model_name.split("-20")[0]):
                    model_info = m_info
                    info["model_name"] = model_name
                    break

            if model_info:
                info.update(
                    {
                        "max_tokens": model_info["max_tokens"],
                        "max_output": model_info["max_output"],
                        "chars_per_token": model_info["chars_per_token"],
                        "supports_vision": "claude-3" in model,
                        "supports_tools": "claude-3" in model,
                        "supports_cache": "claude-3" in model,
                    }
                )

            # Attach pricing by prefix match.
            pricing = self.config.get("pricing", {})
            for price_key in pricing:
                if model.startswith(price_key):
                    info["pricing"] = pricing[price_key]
                    break

        return info

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin status information."""
        stats = await super().get_stats()
        stats.update(
            {
                "estimation_method": "character_based",
                "supported_models_count": len(self.CLAUDE_MODELS),
            }
        )
        return stats
|
||||
269
src/plugins/token/tiktoken_counter.py
Normal file
269
src/plugins/token/tiktoken_counter.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Tiktoken Token计数插件
|
||||
支持OpenAI和其他使用tiktoken的模型
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from .base import TokenCounterPlugin, TokenUsage
|
||||
|
||||
# 尝试导入tiktoken
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
TIKTOKEN_AVAILABLE = True
|
||||
except ImportError:
|
||||
TIKTOKEN_AVAILABLE = False
|
||||
tiktoken = None
|
||||
|
||||
|
||||
class TiktokenCounterPlugin(TokenCounterPlugin):
    """
    Token counter backed by the ``tiktoken`` library.

    Supports OpenAI models and compatible ones. When tiktoken is not
    installed the plugin disables itself and all counts return 0.
    """

    # Model -> tiktoken encoding name.
    MODEL_ENCODINGS = {
        # GPT-4 family
        "gpt-4": "cl100k_base",
        "gpt-4-32k": "cl100k_base",
        "gpt-4-turbo": "cl100k_base",
        "gpt-4-turbo-preview": "cl100k_base",
        "gpt-4o": "o200k_base",
        "gpt-4o-mini": "o200k_base",
        # GPT-3.5 family
        "gpt-3.5-turbo": "cl100k_base",
        "gpt-3.5-turbo-16k": "cl100k_base",
        # Legacy models
        "text-davinci-003": "p50k_base",
        "text-davinci-002": "p50k_base",
        "code-davinci-002": "p50k_base",
        # Embeddings
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }

    # Fixed per-message token overhead by model.
    MESSAGE_OVERHEAD = {
        "gpt-3.5-turbo": 4,  # per message
        "gpt-4": 3,
        "gpt-4-turbo": 3,
        "gpt-4o": 3,
        "gpt-4o-mini": 3,
    }

    def __init__(self, name: str = "tiktoken", config: Optional[Dict[str, Any]] = None):
        super().__init__(name, config)

        # Encoder cache (model -> encoder). Bug fix: initialised
        # unconditionally, BEFORE the availability check — previously it was
        # only set when tiktoken was importable, so get_stats() raised
        # AttributeError on a disabled plugin.
        self._encoders: Dict[str, Any] = {}

        # Pricing table (USD per 1M tokens). Also set unconditionally so the
        # config is complete even when the plugin is disabled.
        default_pricing = {
            "gpt-4o": {"input": 2.5, "output": 10},
            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
            "gpt-4-turbo": {"input": 10, "output": 30},
            "gpt-4": {"input": 30, "output": 60},
            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
            "o1-preview": {"input": 15, "output": 60, "reasoning": 60},
            "o1-mini": {"input": 3, "output": 12, "reasoning": 12},
        }
        self.config["pricing"] = (
            config.get("pricing", default_pricing) if config else default_pricing
        )

        if not TIKTOKEN_AVAILABLE:
            # Disable the plugin; count_* methods then return 0.
            self.enabled = False
            logger.warning("tiktoken not installed, plugin disabled")

    def _get_encoder(self, model: str) -> Any:
        """Return (and cache) the tiktoken encoder for *model*."""
        if model in self._encoders:
            return self._encoders[model]

        # Resolve the encoding name: exact match first, then prefix match.
        encoding_name = None
        if model in self.MODEL_ENCODINGS:
            encoding_name = self.MODEL_ENCODINGS[model]
        else:
            for model_prefix, enc_name in self.MODEL_ENCODINGS.items():
                if model.startswith(model_prefix):
                    encoding_name = enc_name
                    break

        # Unknown model: ask tiktoken directly, falling back to cl100k_base.
        if not encoding_name:
            try:
                encoder = tiktoken.encoding_for_model(model)
                self._encoders[model] = encoder
                return encoder
            except Exception:  # bug fix: narrowed from a bare `except:`
                # tiktoken raises KeyError for unrecognised model names.
                encoding_name = "cl100k_base"

        # Create and cache the encoder.
        encoder = tiktoken.get_encoding(encoding_name)
        self._encoders[model] = encoder
        return encoder

    def supports_model(self, model: str) -> bool:
        """Return True for OpenAI models and known compatible families."""
        openai_models = ["gpt-4", "gpt-3.5", "text-davinci", "text-embedding", "code-davinci", "o1"]
        return any(model.startswith(prefix) for prefix in openai_models)

    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count the tokens of *text*; returns 0 when the plugin is disabled."""
        if not self.enabled:
            return 0

        model = model or self.default_model or "gpt-3.5-turbo"
        encoder = self._get_encoder(model)

        try:
            tokens = encoder.encode(text)
            return len(tokens)
        except Exception as e:
            logger.warning(f"Error counting tokens: {e}")
            # Rough fallback: ~0.75 tokens per character.
            return int(len(text) * 0.75)

    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Count the tokens of a chat message list, including per-message overhead."""
        if not self.enabled:
            return 0

        model = model or self.default_model or "gpt-3.5-turbo"
        encoder = self._get_encoder(model)

        # Fixed token overhead per message for this model family.
        msg_overhead = self.MESSAGE_OVERHEAD.get(model, 3)

        total_tokens = 0

        for message in messages:
            # Base overhead for each message.
            total_tokens += msg_overhead

            # Role tokens.
            role = message.get("role", "")
            if role:
                total_tokens += len(encoder.encode(role))

            # Content tokens.
            content = message.get("content")
            if content:
                if isinstance(content, str):
                    total_tokens += len(encoder.encode(content))
                elif isinstance(content, list):
                    # Multimodal content blocks.
                    for item in content:
                        if item.get("type") == "text":
                            text = item.get("text", "")
                            total_tokens += len(encoder.encode(text))
                        elif item.get("type") == "image_url":
                            # Image token accounting is complex; simplified:
                            # low detail ~85 tokens, high detail ~170 tokens.
                            detail = item.get("image_url", {}).get("detail", "auto")
                            total_tokens += 170 if detail == "high" else 85

            # Name tokens.
            name = message.get("name")
            if name:
                total_tokens += len(encoder.encode(name)) - 1  # name offsets one token

            # Tool calls.
            tool_calls = message.get("tool_calls")
            if tool_calls:
                for tool_call in tool_calls:
                    # Tool-call id.
                    if "id" in tool_call:
                        total_tokens += len(encoder.encode(tool_call["id"]))

                    # Function name and serialized arguments.
                    function = tool_call.get("function", {})
                    if "name" in function:
                        total_tokens += len(encoder.encode(function["name"]))
                    if "arguments" in function:
                        total_tokens += len(encoder.encode(function["arguments"]))

        # Fixed reply-priming overhead.
        total_tokens += 3

        return total_tokens

    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return encoding, limits and pricing information for *model*."""
        info = {"model": model, "supported": self.supports_model(model)}

        if self.supports_model(model):
            # Resolve the encoder and its encoding name.
            encoder = self._get_encoder(model)
            encoding_name = None
            for m, enc in self.MODEL_ENCODINGS.items():
                if model.startswith(m):
                    encoding_name = enc
                    break

            info.update(
                {
                    "encoding": encoding_name or "unknown",
                    "vocab_size": encoder.n_vocab if hasattr(encoder, "n_vocab") else None,
                    "max_tokens": self._get_max_tokens(model),
                    "message_overhead": self.MESSAGE_OVERHEAD.get(model, 3),
                }
            )

            # Attach pricing on exact match.
            pricing = self.config.get("pricing", {})
            if model in pricing:
                info["pricing"] = pricing[model]

        return info

    def _get_max_tokens(self, model: str) -> int:
        """Return the context-window size for *model* (default 4096)."""
        max_tokens_map = {
            "gpt-4": 8192,
            "gpt-4-32k": 32768,
            "gpt-4-turbo": 128000,
            "gpt-4o": 128000,
            "gpt-4o-mini": 128000,
            "gpt-3.5-turbo": 4096,
            "gpt-3.5-turbo-16k": 16384,
            "o1-preview": 128000,
            "o1-mini": 128000,
        }

        # Exact match first.
        if model in max_tokens_map:
            return max_tokens_map[model]

        # Then prefix match.
        for model_prefix, max_tokens in max_tokens_map.items():
            if model.startswith(model_prefix):
                return max_tokens

        # Conservative default.
        return 4096

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin status information."""
        stats = await super().get_stats()
        stats.update(
            {"encoders_cached": len(self._encoders), "tiktoken_available": TIKTOKEN_AVAILABLE}
        )
        return stats
|
||||
Reference in New Issue
Block a user