feat(cache): enhance cache monitoring endpoints and handler integrations

fawney19
2025-12-15 23:12:48 +08:00
parent 718f56ba75
commit cf67160821
4 changed files with 204 additions and 41 deletions
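
Among the endpoint changes, this commit registers a new DELETE route for clearing a single Provider/GlobalModel mapping cache entry. A minimal sketch of calling it with httpx follows; the base URL, router prefix, bearer-token auth, and IDs are hypothetical placeholders, not taken from the diff:

    import httpx

    BASE_URL = "http://localhost:8000/admin/cache"  # hypothetical host and router prefix
    headers = {"Authorization": "Bearer <admin-token>"}  # hypothetical auth scheme

    provider_id = "prov-123"        # placeholder Provider ID
    global_model_id = "gm-456"      # placeholder GlobalModel ID

    resp = httpx.delete(
        f"{BASE_URL}/model-mapping/provider/{provider_id}/{global_model_id}",
        headers=headers,
    )
    print(resp.json())  # per the adapter: status, message, and deleted_keys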

View File

@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.context import ApiRequestContext
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync
@@ -87,19 +88,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
# 2. Try to resolve as a username
user = db.query(User).filter(User.username == identifier).first()
if user:
logger.debug(f"Resolved via username: {identifier} -> {user.id[:8]}...")
logger.debug(f"Resolved via username: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id
# 3. Try to resolve as an email
user = db.query(User).filter(User.email == identifier).first()
if user:
logger.debug(f"Resolved via email: {identifier} -> {user.id[:8]}...")
logger.debug(f"Resolved via email: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id
# 4. Try to resolve as an API key ID
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
if api_key:
logger.debug(f"Resolved via API key ID: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
logger.debug(f"Resolved via API key ID: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
return api_key.user_id
# Unrecognized identifier
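
For context, resolve_user_identifier tries each interpretation in turn (user ID, username, email, API key ID) and returns the canonical user ID, or None when nothing matches. A hedged usage sketch; the identifier value is a placeholder:

    # `db` is a SQLAlchemy Session, e.g. from the get_db dependency.
    user_id = resolve_user_identifier(db, "alice@example.com")
    if user_id is None:
        raise HTTPException(status_code=404, detail="Unknown user identifier")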
@@ -111,7 +112,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
async def get_cache_stats(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Get cache affinity statistics
@@ -131,7 +132,7 @@ async def get_user_affinity(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Query all cache affinities for the specified user
@@ -157,7 +158,7 @@ async def list_affinities(
limit: int = Query(100, ge=1, le=1000, description="Maximum number of results"),
offset: int = Query(0, ge=0, description="Offset"),
db: Session = Depends(get_db),
):
) -> Any:
"""
List all cache affinities, optionally filtered by keyword
@@ -173,7 +174,7 @@ async def clear_user_cache(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear cache affinity for a specific user
@@ -188,7 +189,7 @@ async def clear_user_cache(
async def clear_all_cache(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear all cache affinities
@@ -203,7 +204,7 @@ async def clear_provider_cache(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear cache affinities for a specific provider
@@ -218,7 +219,7 @@ async def clear_provider_cache(
async def get_cache_config(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Get cache-related configuration
@@ -234,7 +235,7 @@ async def get_cache_config(
async def get_cache_metrics(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Expose cache scheduling metrics in Prometheus text format for easy Grafana integration.
"""
@@ -246,7 +247,7 @@ async def get_cache_metrics(
class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -266,7 +267,7 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -391,7 +392,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
class AdminGetUserAffinityAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
user_id = resolve_user_identifier(db, self.user_identifier)
@@ -472,7 +473,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
redis_client = get_redis_client_sync()
if not redis_client:
@@ -682,7 +683,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
class AdminClearUserCacheAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
@@ -786,7 +787,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -806,7 +807,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
class AdminClearProviderCacheAdapter(AdminApiAdapter):
provider_id: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -829,7 +830,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
@@ -878,7 +879,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
async def get_model_mapping_cache_stats(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Get model mapping cache statistics
@@ -895,7 +896,7 @@ async def get_model_mapping_cache_stats(
async def clear_all_model_mapping_cache(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear all model mapping caches
@@ -910,7 +911,7 @@ async def clear_model_mapping_cache_by_name(
model_name: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear the mapping cache for the specified model name
@@ -921,8 +922,28 @@ async def clear_model_mapping_cache_by_name(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
async def clear_provider_model_mapping_cache(
provider_id: str,
global_model_id: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
Clear the model mapping cache for the specified Provider and GlobalModel
Args:
- provider_id: Provider ID
- global_model_id: GlobalModel ID
"""
adapter = AdminClearProviderModelMappingCacheAdapter(
provider_id=provider_id, global_model_id=global_model_id
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
import json
from src.clients.redis_client import get_redis_client
@@ -955,6 +976,8 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if key_str.startswith("model:id:"):
model_id_keys.append(key_str)
elif key_str.startswith("model:provider_global:"):
# Filter out hit-count statistics keys; keep only actual cache keys
if not key_str.startswith("model:provider_global:hits:"):
provider_global_keys.append(key_str)
async for key in redis.scan_iter(match="global_model:*", count=100):
@@ -1067,6 +1090,85 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
# Sort by mapping_name
mappings.sort(key=lambda x: x["mapping_name"])
# 3. Parse the provider_global cache (provider-level model resolution cache)
provider_model_mappings = []
# Preload Provider and GlobalModel data
provider_map = {str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()}
global_model_map = {str(gm.id): gm for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()}
for key in provider_global_keys[:100]: # process at most 100 keys
# Key format: model:provider_global:{provider_id}:{global_model_id}
try:
parts = key.replace("model:provider_global:", "").split(":")
if len(parts) != 2:
continue
provider_id, global_model_id = parts
cached_value = await redis.get(key)
ttl = await redis.ttl(key)
# Fetch the hit count
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
hit_count_raw = await redis.get(hit_count_key)
hit_count = int(hit_count_raw) if hit_count_raw else 0
if cached_value:
cached_str = (
cached_value.decode()
if isinstance(cached_value, bytes)
else cached_value
)
try:
cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name")
provider_model_aliases = cached_data.get("provider_model_aliases", [])
# Look up Provider and GlobalModel info
provider = provider_map.get(provider_id)
global_model = global_model_map.get(global_model_id)
if provider and global_model:
# Extract alias names
alias_names = []
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict) and alias_entry.get("name"):
alias_names.append(alias_entry["name"])
# Skip when provider_model_name is empty
if not provider_model_name:
continue
# Only show entries with an actual mapping:
# 1. the global model name differs from the provider model name (name mapping), or
# 2. aliases are configured
has_name_mapping = global_model.name != provider_model_name
has_aliases = len(alias_names) > 0
if has_name_mapping or has_aliases:
# Build the alias list for display.
# If there is only a name mapping and no aliases, use global_model_name as the "request name".
display_aliases = alias_names if alias_names else [global_model.name]
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,
"provider_model_name": provider_model_name,
"aliases": display_aliases,
"ttl": ttl if ttl > 0 else None,
"hit_count": hit_count,
})
except json.JSONDecodeError:
pass
except Exception as e:
logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")
# Sort by provider_name + global_model_name
provider_model_mappings.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
response_data = {
"available": True,
"ttl_seconds": CacheTTL.MODEL,
@@ -1079,6 +1181,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
"global_model_resolve": len(global_model_resolve_keys),
},
"mappings": mappings,
"provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
"unmapped": unmapped_entries if unmapped_entries else None,
}
@@ -1094,7 +1197,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1136,7 +1239,7 @@ class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
model_name: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1176,3 +1279,55 @@ class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
except Exception as exc:
logger.exception(f"清除模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
provider_id: str
global_model_id: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
redis = await get_redis_client(require_redis=False)
if not redis:
raise HTTPException(status_code=503, detail="Redis is not enabled")
deleted_keys = []
# Clear the provider_global cache entry
provider_global_key = f"model:provider_global:{self.provider_id}:{self.global_model_id}"
if await redis.exists(provider_global_key):
await redis.delete(provider_global_key)
deleted_keys.append(provider_global_key)
# Clear the corresponding hit-count entry
hit_count_key = f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}"
if await redis.exists(hit_count_key):
await redis.delete(hit_count_key)
deleted_keys.append(hit_count_key)
logger.info(
f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
)
context.add_audit_metadata(
action="provider_model_mapping_cache_clear",
provider_id=self.provider_id,
global_model_id=self.global_model_id,
deleted_keys=deleted_keys,
)
return {
"status": "ok",
"message": "已清除 Provider 模型映射缓存",
"provider_id": self.provider_id,
"global_model_id": self.global_model_id,
"deleted_keys": deleted_keys,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
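
This file settles on a two-key layout per (Provider, GlobalModel) pair: a value key model:provider_global:{provider_id}:{global_model_id} plus a companion counter under the hits: prefix, which the stats adapter filters out and the clear adapter deletes alongside the value. A minimal sketch of that layout, assuming an async Redis client from redis.asyncio:

    import redis.asyncio as aioredis

    def mapping_keys(provider_id: str, global_model_id: str) -> tuple[str, str]:
        # Key layout as used in the diff: cached value plus its hit counter.
        value_key = f"model:provider_global:{provider_id}:{global_model_id}"
        hits_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
        return value_key, hits_key

    async def clear_mapping(redis: aioredis.Redis, provider_id: str, global_model_id: str) -> list[str]:
        deleted: list[str] = []
        for key in mapping_keys(provider_id, global_model_id):
            if await redis.exists(key):  # mirrors the adapter's exists-then-delete pattern
                await redis.delete(key)
                deleted.append(key)
        return deleted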

View File

@@ -395,3 +395,24 @@ class BaseMessageHandler:
# Create a background task; do not block the current stream
asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
Args:
message: 错误消息前缀
error: 异常对象
"""
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# Business exception: concise log, no stack trace
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:
# Unknown exception: full stack trace
logger.exception(f"{message}: {error}")

View File

@@ -382,7 +382,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
except Exception as e:
logger.exception(f"流式请求失败: {e}")
self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise

View File

@@ -413,20 +413,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
)
except Exception as e:
# For known business exceptions, log only a concise message without the full stack trace
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# Business exception: concise log
logger.error(f"Streaming request failed: [{type(e).__name__}] {e}")
else:
# Unknown exception: full stack trace
logger.exception(f"Streaming request failed: {e}")
self._log_request_error("Streaming request failed", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise