feat(api): add model mapping cache management endpoints

fawney19
2025-12-15 20:39:51 +08:00
parent 34d480910a
commit f2cd96c34c


@@ -869,3 +869,310 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
            dynamic_reservation_enabled=True,
        )
        return response


# ==================== Model mapping cache management ====================


@router.get("/model-mapping/stats")
async def get_model_mapping_cache_stats(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Get model mapping cache statistics.

    Returns:
    - the number of cache keys
    - the configured cache TTL
    - key counts per cache type
    """
    adapter = AdminModelMappingCacheStatsAdapter()
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping")
async def clear_all_model_mapping_cache(
request: Request,
db: Session = Depends(get_db),
):
"""
清除所有模型映射缓存
警告: 这会影响所有模型解析,请谨慎使用
"""
adapter = AdminClearAllModelMappingCacheAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping/{model_name}")
async def clear_model_mapping_cache_by_name(
model_name: str,
request: Request,
db: Session = Depends(get_db),
):
"""
清除指定模型名称的映射缓存
参数:
- model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
"""
adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
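

# Illustrative sketch only (not part of this commit): one way to exercise the new
# endpoints with httpx. The base URL, the "/admin/cache" router prefix, the bearer
# token header, and the "gpt-4o" model name are assumptions; adjust them to however
# the admin router is actually mounted and authenticated.
async def _demo_model_mapping_cache_endpoints() -> None:
    import httpx  # assumed to be available in the environment

    headers = {"Authorization": "Bearer <admin-token>"}  # placeholder credential
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Inspect the current cache state.
        stats = await client.get("/admin/cache/model-mapping/stats", headers=headers)
        print(stats.json())
        # Clear the cache for a single (hypothetical) model name.
        await client.delete("/admin/cache/model-mapping/gpt-4o", headers=headers)
        # Clear every model mapping cache entry.
        await client.delete("/admin/cache/model-mapping", headers=headers)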


class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
    async def handle(self, context):  # type: ignore[override]
        import json

        from src.clients.redis_client import get_redis_client
        from src.config.constants import CacheTTL
        from src.models.database import GlobalModel, Model, Provider

        db = context.db
        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return {
                    "status": "ok",
                    "data": {
                        "available": False,
                        "message": "Redis is not enabled; the model mapping cache is unavailable",
                    },
                }

            # Count cache keys per type
            model_id_keys = []
            global_model_id_keys = []
            global_model_name_keys = []
            global_model_resolve_keys = []
            provider_global_keys = []

            # Scan all model-related cache keys
            async for key in redis.scan_iter(match="model:*", count=100):
                key_str = key.decode() if isinstance(key, bytes) else key
                if key_str.startswith("model:id:"):
                    model_id_keys.append(key_str)
                elif key_str.startswith("model:provider_global:"):
                    provider_global_keys.append(key_str)

            async for key in redis.scan_iter(match="global_model:*", count=100):
                key_str = key.decode() if isinstance(key, bytes) else key
                if key_str.startswith("global_model:id:"):
                    global_model_id_keys.append(key_str)
                elif key_str.startswith("global_model:name:"):
                    global_model_name_keys.append(key_str)
                elif key_str.startswith("global_model:resolve:"):
                    global_model_resolve_keys.append(key_str)

            total_keys = (
                len(model_id_keys)
                + len(global_model_id_keys)
                + len(global_model_name_keys)
                + len(global_model_resolve_keys)
                + len(provider_global_keys)
            )

            # Parse cached values and build the mapping list
            mappings = []
            unmapped_entries = []
            for key in global_model_resolve_keys[:100]:  # process at most 100 keys
                mapping_name = key.replace("global_model:resolve:", "")
                try:
                    cached_value = await redis.get(key)
                    ttl = await redis.ttl(key)
                    if cached_value:
                        cached_str = (
                            cached_value.decode()
                            if isinstance(cached_value, bytes)
                            else cached_value
                        )
                        if cached_str == "NOT_FOUND":
                            unmapped_entries.append({
                                "mapping_name": mapping_name,
                                "status": "not_found",
                                "ttl": ttl if ttl > 0 else None,
                            })
                        else:
                            try:
                                cached_data = json.loads(cached_str)
                                global_model_id = cached_data.get("id")
                                global_model_name = cached_data.get("name")
                                global_model_display_name = cached_data.get("display_name")
                                # Skip entries where mapping_name == global_model_name
                                # (a direct match, not a mapping)
                                if mapping_name == global_model_name:
                                    continue
                                # Look up which providers configure this mapping name
                                provider_names = []
                                if global_model_id:
                                    models = (
                                        db.query(Model, Provider)
                                        .join(Provider, Model.provider_id == Provider.id)
                                        .filter(
                                            Model.global_model_id == global_model_id,
                                            Model.is_active,
                                            Provider.is_active,
                                        )
                                        .all()
                                    )
                                    # Only include providers that configure this mapping name
                                    for model, provider in models:
                                        # Check the primary model name
                                        if model.provider_model_name == mapping_name:
                                            provider_names.append(
                                                provider.display_name or provider.name
                                            )
                                            continue
                                        # Check the alias list
                                        if model.provider_model_aliases:
                                            alias_names = [
                                                a.get("name")
                                                for a in model.provider_model_aliases
                                                if isinstance(a, dict)
                                            ]
                                            if mapping_name in alias_names:
                                                provider_names.append(
                                                    provider.display_name or provider.name
                                                )
                                provider_names = sorted(list(set(provider_names)))
                                mappings.append({
                                    "mapping_name": mapping_name,
                                    "global_model_name": global_model_name,
                                    "global_model_display_name": global_model_display_name,
                                    "providers": provider_names,
                                    "ttl": ttl if ttl > 0 else None,
                                })
                            except json.JSONDecodeError:
                                unmapped_entries.append({
                                    "mapping_name": mapping_name,
                                    "status": "invalid",
                                    "ttl": ttl if ttl > 0 else None,
                                })
                except Exception as e:
                    logger.warning(f"Failed to parse cache key {key}: {e}")
                    unmapped_entries.append({
                        "mapping_name": mapping_name,
                        "status": "error",
                        "ttl": None,
                    })

            # Sort by mapping_name
            mappings.sort(key=lambda x: x["mapping_name"])

            response_data = {
                "available": True,
                "ttl_seconds": CacheTTL.MODEL,
                "total_keys": total_keys,
                "breakdown": {
                    "model_by_id": len(model_id_keys),
                    "model_by_provider_global": len(provider_global_keys),
                    "global_model_by_id": len(global_model_id_keys),
                    "global_model_by_name": len(global_model_name_keys),
                    "global_model_resolve": len(global_model_resolve_keys),
                },
                "mappings": mappings,
                "unmapped": unmapped_entries if unmapped_entries else None,
            }
            context.add_audit_metadata(
                action="model_mapping_cache_stats",
                total_keys=total_keys,
            )
            return {"status": "ok", "data": response_data}
        except Exception as exc:
            logger.exception(f"Failed to get model mapping cache stats: {exc}")
            raise HTTPException(status_code=500, detail=f"Failed to get cache stats: {exc}")
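

# Illustrative sketch only (not part of this commit): the shape of the "data" payload
# produced by AdminModelMappingCacheStatsAdapter, with made-up example values, for
# anyone wiring a dashboard or test against this endpoint. The model and provider
# names below are hypothetical.
EXAMPLE_MODEL_MAPPING_STATS_PAYLOAD = {
    "available": True,
    "ttl_seconds": 300,  # mirrors CacheTTL.MODEL; the real value may differ
    "total_keys": 12,
    "breakdown": {
        "model_by_id": 4,
        "model_by_provider_global": 2,
        "global_model_by_id": 3,
        "global_model_by_name": 2,
        "global_model_resolve": 1,
    },
    "mappings": [
        {
            "mapping_name": "gpt-4o-latest",
            "global_model_name": "gpt-4o",
            "global_model_display_name": "GPT-4o",
            "providers": ["Example Provider"],
            "ttl": 120,
        }
    ],
    "unmapped": None,
}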


class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
    async def handle(self, context):  # type: ignore[override]
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis is not enabled")

            deleted_count = 0
            # Collect and delete every model-related cache key
            keys_to_delete = []
            async for key in redis.scan_iter(match="model:*", count=100):
                keys_to_delete.append(key)
            async for key in redis.scan_iter(match="global_model:*", count=100):
                keys_to_delete.append(key)

            if keys_to_delete:
                deleted_count = await redis.delete(*keys_to_delete)

            logger.warning(f"Cleared all model mapping caches (admin operation): {deleted_count} keys")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_all",
                deleted_count=deleted_count,
            )
            return {
                "status": "ok",
                "message": "All model mapping caches cleared",
                "deleted_count": deleted_count,
            }
        except HTTPException:
            raise
        except Exception as exc:
            logger.exception(f"Failed to clear model mapping caches: {exc}")
            raise HTTPException(status_code=500, detail=f"Failed to clear cache: {exc}")
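

# Design note (sketch, not part of this commit): the adapter above gathers every key and
# issues a single DELETE. On a large cache it can be gentler on Redis to delete in
# fixed-size batches while scanning. This helper is illustrative only and assumes the
# same asyncio Redis client returned by get_redis_client.
async def _clear_model_mapping_cache_in_batches(redis, batch_size: int = 500) -> int:
    deleted = 0
    batch = []
    for pattern in ("model:*", "global_model:*"):
        async for key in redis.scan_iter(match=pattern, count=100):
            batch.append(key)
            if len(batch) >= batch_size:
                deleted += await redis.delete(*batch)
                batch.clear()
    if batch:
        deleted += await redis.delete(*batch)
    return deleted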


@dataclass
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
    model_name: str

    async def handle(self, context):  # type: ignore[override]
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis is not enabled")

            deleted_keys = []

            # Clear the resolve cache
            resolve_key = f"global_model:resolve:{self.model_name}"
            if await redis.exists(resolve_key):
                await redis.delete(resolve_key)
                deleted_keys.append(resolve_key)

            # Clear the name cache
            name_key = f"global_model:name:{self.model_name}"
            if await redis.exists(name_key):
                await redis.delete(name_key)
                deleted_keys.append(name_key)

            logger.info(f"Cleared model mapping cache: model_name={self.model_name}, deleted_keys={deleted_keys}")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_by_name",
                model_name=self.model_name,
                deleted_keys=deleted_keys,
            )
            return {
                "status": "ok",
                "message": f"Mapping cache cleared for model {self.model_name}",
                "model_name": self.model_name,
                "deleted_keys": deleted_keys,
            }
        except HTTPException:
            raise
        except Exception as exc:
            logger.exception(f"Failed to clear model mapping cache: {exc}")
            raise HTTPException(status_code=500, detail=f"Failed to clear cache: {exc}")
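

# Illustrative sketch only (not part of this commit): verifying from a standalone script
# that the two keys targeted by the adapter above are actually gone, using redis.asyncio
# directly. The Redis URL is an assumption; point it at the deployment's instance.
async def _check_model_mapping_keys(model_name: str) -> dict:
    import redis.asyncio as aioredis  # assumed project dependency (redis-py)

    client = aioredis.from_url("redis://localhost:6379/0")
    try:
        resolve_key = f"global_model:resolve:{model_name}"
        name_key = f"global_model:name:{model_name}"
        # EXISTS returns 0 once the key has been cleared.
        return {
            resolve_key: await client.exists(resolve_key),
            name_key: await client.exists(name_key),
        }
    finally:
        await client.close()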