mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
feat(api): add model mapping cache management endpoints
This commit is contained in:
@@ -869,3 +869,310 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
dynamic_reservation_enabled=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# ==================== 模型映射缓存管理 ====================
|
||||
|
||||
|
||||
@router.get("/model-mapping/stats")
|
||||
async def get_model_mapping_cache_stats(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取模型映射缓存统计信息
|
||||
|
||||
返回:
|
||||
- 缓存键数量
|
||||
- 缓存 TTL 配置
|
||||
- 各类型缓存数量
|
||||
"""
|
||||
adapter = AdminModelMappingCacheStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/model-mapping")
|
||||
async def clear_all_model_mapping_cache(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
清除所有模型映射缓存
|
||||
|
||||
警告: 这会影响所有模型解析,请谨慎使用
|
||||
"""
|
||||
adapter = AdminClearAllModelMappingCacheAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/model-mapping/{model_name}")
|
||||
async def clear_model_mapping_cache_by_name(
|
||||
model_name: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
清除指定模型名称的映射缓存
|
||||
|
||||
参数:
|
||||
- model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
|
||||
"""
|
||||
adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
import json
|
||||
|
||||
from src.clients.redis_client import get_redis_client
|
||||
from src.config.constants import CacheTTL
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
|
||||
db = context.db
|
||||
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return {
|
||||
"status": "ok",
|
||||
"data": {
|
||||
"available": False,
|
||||
"message": "Redis 未启用,模型映射缓存不可用",
|
||||
},
|
||||
}
|
||||
|
||||
# 统计各类型缓存键数量
|
||||
model_id_keys = []
|
||||
global_model_id_keys = []
|
||||
global_model_name_keys = []
|
||||
global_model_resolve_keys = []
|
||||
provider_global_keys = []
|
||||
|
||||
# 扫描所有模型相关的缓存键
|
||||
async for key in redis.scan_iter(match="model:*", count=100):
|
||||
key_str = key.decode() if isinstance(key, bytes) else key
|
||||
if key_str.startswith("model:id:"):
|
||||
model_id_keys.append(key_str)
|
||||
elif key_str.startswith("model:provider_global:"):
|
||||
provider_global_keys.append(key_str)
|
||||
|
||||
async for key in redis.scan_iter(match="global_model:*", count=100):
|
||||
key_str = key.decode() if isinstance(key, bytes) else key
|
||||
if key_str.startswith("global_model:id:"):
|
||||
global_model_id_keys.append(key_str)
|
||||
elif key_str.startswith("global_model:name:"):
|
||||
global_model_name_keys.append(key_str)
|
||||
elif key_str.startswith("global_model:resolve:"):
|
||||
global_model_resolve_keys.append(key_str)
|
||||
|
||||
total_keys = (
|
||||
len(model_id_keys)
|
||||
+ len(global_model_id_keys)
|
||||
+ len(global_model_name_keys)
|
||||
+ len(global_model_resolve_keys)
|
||||
+ len(provider_global_keys)
|
||||
)
|
||||
|
||||
# 解析缓存内容,构建映射列表
|
||||
mappings = []
|
||||
unmapped_entries = []
|
||||
|
||||
for key in global_model_resolve_keys[:100]: # 最多处理 100 个
|
||||
mapping_name = key.replace("global_model:resolve:", "")
|
||||
try:
|
||||
cached_value = await redis.get(key)
|
||||
ttl = await redis.ttl(key)
|
||||
|
||||
if cached_value:
|
||||
cached_str = (
|
||||
cached_value.decode()
|
||||
if isinstance(cached_value, bytes)
|
||||
else cached_value
|
||||
)
|
||||
|
||||
if cached_str == "NOT_FOUND":
|
||||
unmapped_entries.append({
|
||||
"mapping_name": mapping_name,
|
||||
"status": "not_found",
|
||||
"ttl": ttl if ttl > 0 else None,
|
||||
})
|
||||
else:
|
||||
try:
|
||||
cached_data = json.loads(cached_str)
|
||||
global_model_id = cached_data.get("id")
|
||||
global_model_name = cached_data.get("name")
|
||||
global_model_display_name = cached_data.get("display_name")
|
||||
|
||||
# 跳过 mapping_name == global_model_name 的情况(直接匹配,不是映射)
|
||||
if mapping_name == global_model_name:
|
||||
continue
|
||||
|
||||
# 查询哪些 Provider 配置了这个映射名称
|
||||
provider_names = []
|
||||
if global_model_id:
|
||||
models = (
|
||||
db.query(Model, Provider)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.filter(
|
||||
Model.global_model_id == global_model_id,
|
||||
Model.is_active,
|
||||
Provider.is_active,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# 只显示配置了该映射名称的 Provider
|
||||
for model, provider in models:
|
||||
# 检查是否是主模型名称
|
||||
if model.provider_model_name == mapping_name:
|
||||
provider_names.append(
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
continue
|
||||
# 检查是否在别名列表中
|
||||
if model.provider_model_aliases:
|
||||
alias_names = [
|
||||
a.get("name")
|
||||
for a in model.provider_model_aliases
|
||||
if isinstance(a, dict)
|
||||
]
|
||||
if mapping_name in alias_names:
|
||||
provider_names.append(
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
provider_names = sorted(list(set(provider_names)))
|
||||
|
||||
mappings.append({
|
||||
"mapping_name": mapping_name,
|
||||
"global_model_name": global_model_name,
|
||||
"global_model_display_name": global_model_display_name,
|
||||
"providers": provider_names,
|
||||
"ttl": ttl if ttl > 0 else None,
|
||||
})
|
||||
|
||||
except json.JSONDecodeError:
|
||||
unmapped_entries.append({
|
||||
"mapping_name": mapping_name,
|
||||
"status": "invalid",
|
||||
"ttl": ttl if ttl > 0 else None,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"解析缓存键 {key} 失败: {e}")
|
||||
unmapped_entries.append({
|
||||
"mapping_name": mapping_name,
|
||||
"status": "error",
|
||||
"ttl": None,
|
||||
})
|
||||
|
||||
# 按 mapping_name 排序
|
||||
mappings.sort(key=lambda x: x["mapping_name"])
|
||||
|
||||
response_data = {
|
||||
"available": True,
|
||||
"ttl_seconds": CacheTTL.MODEL,
|
||||
"total_keys": total_keys,
|
||||
"breakdown": {
|
||||
"model_by_id": len(model_id_keys),
|
||||
"model_by_provider_global": len(provider_global_keys),
|
||||
"global_model_by_id": len(global_model_id_keys),
|
||||
"global_model_by_name": len(global_model_name_keys),
|
||||
"global_model_resolve": len(global_model_resolve_keys),
|
||||
},
|
||||
"mappings": mappings,
|
||||
"unmapped": unmapped_entries if unmapped_entries else None,
|
||||
}
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="model_mapping_cache_stats",
|
||||
total_keys=total_keys,
|
||||
)
|
||||
return {"status": "ok", "data": response_data}
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f"获取模型映射缓存统计失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计失败: {exc}")
|
||||
|
||||
|
||||
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
raise HTTPException(status_code=503, detail="Redis 未启用")
|
||||
|
||||
deleted_count = 0
|
||||
|
||||
# 删除所有模型相关的缓存键
|
||||
keys_to_delete = []
|
||||
async for key in redis.scan_iter(match="model:*", count=100):
|
||||
keys_to_delete.append(key)
|
||||
async for key in redis.scan_iter(match="global_model:*", count=100):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
if keys_to_delete:
|
||||
deleted_count = await redis.delete(*keys_to_delete)
|
||||
|
||||
logger.warning(f"已清除所有模型映射缓存(管理员操作): {deleted_count} 个键")
|
||||
context.add_audit_metadata(
|
||||
action="model_mapping_cache_clear_all",
|
||||
deleted_count=deleted_count,
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"已清除所有模型映射缓存",
|
||||
"deleted_count": deleted_count,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除模型映射缓存失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
|
||||
model_name: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
raise HTTPException(status_code=503, detail="Redis 未启用")
|
||||
|
||||
deleted_keys = []
|
||||
|
||||
# 清除 resolve 缓存
|
||||
resolve_key = f"global_model:resolve:{self.model_name}"
|
||||
if await redis.exists(resolve_key):
|
||||
await redis.delete(resolve_key)
|
||||
deleted_keys.append(resolve_key)
|
||||
|
||||
# 清除 name 缓存
|
||||
name_key = f"global_model:name:{self.model_name}"
|
||||
if await redis.exists(name_key):
|
||||
await redis.delete(name_key)
|
||||
deleted_keys.append(name_key)
|
||||
|
||||
logger.info(f"已清除模型映射缓存: model_name={self.model_name}, 删除键={deleted_keys}")
|
||||
context.add_audit_metadata(
|
||||
action="model_mapping_cache_clear_by_name",
|
||||
model_name=self.model_name,
|
||||
deleted_keys=deleted_keys,
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"已清除模型 {self.model_name} 的映射缓存",
|
||||
"model_name": self.model_name,
|
||||
"deleted_keys": deleted_keys,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除模型映射缓存失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
Reference in New Issue
Block a user