diff --git a/src/api/admin/monitoring/cache.py b/src/api/admin/monitoring/cache.py index 43dda40..79b3044 100644 --- a/src/api/admin/monitoring/cache.py +++ b/src/api/admin/monitoring/cache.py @@ -869,3 +869,310 @@ class AdminCacheConfigAdapter(AdminApiAdapter): dynamic_reservation_enabled=True, ) return response + + +# ==================== 模型映射缓存管理 ==================== + + +@router.get("/model-mapping/stats") +async def get_model_mapping_cache_stats( + request: Request, + db: Session = Depends(get_db), +): + """ + 获取模型映射缓存统计信息 + + 返回: + - 缓存键数量 + - 缓存 TTL 配置 + - 各类型缓存数量 + """ + adapter = AdminModelMappingCacheStatsAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.delete("/model-mapping") +async def clear_all_model_mapping_cache( + request: Request, + db: Session = Depends(get_db), +): + """ + 清除所有模型映射缓存 + + 警告: 这会影响所有模型解析,请谨慎使用 + """ + adapter = AdminClearAllModelMappingCacheAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.delete("/model-mapping/{model_name}") +async def clear_model_mapping_cache_by_name( + model_name: str, + request: Request, + db: Session = Depends(get_db), +): + """ + 清除指定模型名称的映射缓存 + + 参数: + - model_name: 模型名称(可以是 GlobalModel.name 或映射名称) + """ + adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name) + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +class AdminModelMappingCacheStatsAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + import json + + from src.clients.redis_client import get_redis_client + from src.config.constants import CacheTTL + from src.models.database import GlobalModel, Model, Provider + + db = context.db + + try: + redis = await get_redis_client(require_redis=False) + if not redis: + return { + "status": "ok", + "data": { + "available": False, + "message": "Redis 未启用,模型映射缓存不可用", + }, + } + + # 统计各类型缓存键数量 + model_id_keys = [] + global_model_id_keys = [] + global_model_name_keys = [] + global_model_resolve_keys = [] + provider_global_keys = [] + + # 扫描所有模型相关的缓存键 + async for key in redis.scan_iter(match="model:*", count=100): + key_str = key.decode() if isinstance(key, bytes) else key + if key_str.startswith("model:id:"): + model_id_keys.append(key_str) + elif key_str.startswith("model:provider_global:"): + provider_global_keys.append(key_str) + + async for key in redis.scan_iter(match="global_model:*", count=100): + key_str = key.decode() if isinstance(key, bytes) else key + if key_str.startswith("global_model:id:"): + global_model_id_keys.append(key_str) + elif key_str.startswith("global_model:name:"): + global_model_name_keys.append(key_str) + elif key_str.startswith("global_model:resolve:"): + global_model_resolve_keys.append(key_str) + + total_keys = ( + len(model_id_keys) + + len(global_model_id_keys) + + len(global_model_name_keys) + + len(global_model_resolve_keys) + + len(provider_global_keys) + ) + + # 解析缓存内容,构建映射列表 + mappings = [] + unmapped_entries = [] + + for key in global_model_resolve_keys[:100]: # 最多处理 100 个 + mapping_name = key.replace("global_model:resolve:", "") + try: + cached_value = await redis.get(key) + ttl = await redis.ttl(key) + + if cached_value: + cached_str = ( + cached_value.decode() + if isinstance(cached_value, bytes) + else cached_value + ) + + if cached_str == "NOT_FOUND": + unmapped_entries.append({ + "mapping_name": mapping_name, + "status": "not_found", + "ttl": ttl if ttl > 0 else None, + }) + else: + try: + cached_data = json.loads(cached_str) + global_model_id = cached_data.get("id") + global_model_name = cached_data.get("name") + global_model_display_name = cached_data.get("display_name") + + # 跳过 mapping_name == global_model_name 的情况(直接匹配,不是映射) + if mapping_name == global_model_name: + continue + + # 查询哪些 Provider 配置了这个映射名称 + provider_names = [] + if global_model_id: + models = ( + db.query(Model, Provider) + .join(Provider, Model.provider_id == Provider.id) + .filter( + Model.global_model_id == global_model_id, + Model.is_active, + Provider.is_active, + ) + .all() + ) + # 只显示配置了该映射名称的 Provider + for model, provider in models: + # 检查是否是主模型名称 + if model.provider_model_name == mapping_name: + provider_names.append( + provider.display_name or provider.name + ) + continue + # 检查是否在别名列表中 + if model.provider_model_aliases: + alias_names = [ + a.get("name") + for a in model.provider_model_aliases + if isinstance(a, dict) + ] + if mapping_name in alias_names: + provider_names.append( + provider.display_name or provider.name + ) + provider_names = sorted(list(set(provider_names))) + + mappings.append({ + "mapping_name": mapping_name, + "global_model_name": global_model_name, + "global_model_display_name": global_model_display_name, + "providers": provider_names, + "ttl": ttl if ttl > 0 else None, + }) + + except json.JSONDecodeError: + unmapped_entries.append({ + "mapping_name": mapping_name, + "status": "invalid", + "ttl": ttl if ttl > 0 else None, + }) + except Exception as e: + logger.warning(f"解析缓存键 {key} 失败: {e}") + unmapped_entries.append({ + "mapping_name": mapping_name, + "status": "error", + "ttl": None, + }) + + # 按 mapping_name 排序 + mappings.sort(key=lambda x: x["mapping_name"]) + + response_data = { + "available": True, + "ttl_seconds": CacheTTL.MODEL, + "total_keys": total_keys, + "breakdown": { + "model_by_id": len(model_id_keys), + "model_by_provider_global": len(provider_global_keys), + "global_model_by_id": len(global_model_id_keys), + "global_model_by_name": len(global_model_name_keys), + "global_model_resolve": len(global_model_resolve_keys), + }, + "mappings": mappings, + "unmapped": unmapped_entries if unmapped_entries else None, + } + + context.add_audit_metadata( + action="model_mapping_cache_stats", + total_keys=total_keys, + ) + return {"status": "ok", "data": response_data} + + except Exception as exc: + logger.exception(f"获取模型映射缓存统计失败: {exc}") + raise HTTPException(status_code=500, detail=f"获取统计失败: {exc}") + + +class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + from src.clients.redis_client import get_redis_client + + try: + redis = await get_redis_client(require_redis=False) + if not redis: + raise HTTPException(status_code=503, detail="Redis 未启用") + + deleted_count = 0 + + # 删除所有模型相关的缓存键 + keys_to_delete = [] + async for key in redis.scan_iter(match="model:*", count=100): + keys_to_delete.append(key) + async for key in redis.scan_iter(match="global_model:*", count=100): + keys_to_delete.append(key) + + if keys_to_delete: + deleted_count = await redis.delete(*keys_to_delete) + + logger.warning(f"已清除所有模型映射缓存(管理员操作): {deleted_count} 个键") + context.add_audit_metadata( + action="model_mapping_cache_clear_all", + deleted_count=deleted_count, + ) + return { + "status": "ok", + "message": f"已清除所有模型映射缓存", + "deleted_count": deleted_count, + } + + except HTTPException: + raise + except Exception as exc: + logger.exception(f"清除模型映射缓存失败: {exc}") + raise HTTPException(status_code=500, detail=f"清除失败: {exc}") + + +@dataclass +class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter): + model_name: str + + async def handle(self, context): # type: ignore[override] + from src.clients.redis_client import get_redis_client + + try: + redis = await get_redis_client(require_redis=False) + if not redis: + raise HTTPException(status_code=503, detail="Redis 未启用") + + deleted_keys = [] + + # 清除 resolve 缓存 + resolve_key = f"global_model:resolve:{self.model_name}" + if await redis.exists(resolve_key): + await redis.delete(resolve_key) + deleted_keys.append(resolve_key) + + # 清除 name 缓存 + name_key = f"global_model:name:{self.model_name}" + if await redis.exists(name_key): + await redis.delete(name_key) + deleted_keys.append(name_key) + + logger.info(f"已清除模型映射缓存: model_name={self.model_name}, 删除键={deleted_keys}") + context.add_audit_metadata( + action="model_mapping_cache_clear_by_name", + model_name=self.model_name, + deleted_keys=deleted_keys, + ) + return { + "status": "ok", + "message": f"已清除模型 {self.model_name} 的映射缓存", + "model_name": self.model_name, + "deleted_keys": deleted_keys, + } + + except HTTPException: + raise + except Exception as exc: + logger.exception(f"清除模型映射缓存失败: {exc}") + raise HTTPException(status_code=500, detail=f"清除失败: {exc}")