4 Commits

10 changed files with 478 additions and 160 deletions

View File

@@ -12,8 +12,6 @@ services:
      TZ: Asia/Shanghai
    volumes:
      - postgres_data:/var/lib/postgresql/data
-   ports:
-     - "${DB_PORT:-5432}:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 5s
@@ -27,8 +25,6 @@ services:
    command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
    volumes:
      - redis_data:/data
-   ports:
-     - "${REDIS_PORT:-6379}:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
      interval: 5s

View File

@@ -290,6 +290,19 @@ export interface UnmappedEntry {
ttl: number | null
}
// Provider 模型映射缓存(Redis 缓存)
export interface ProviderModelMapping {
provider_id: string
provider_name: string
global_model_id: string
global_model_name: string
global_model_display_name: string | null
provider_model_name: string
aliases: string[] | null
ttl: number | null
hit_count: number
}
export interface ModelMappingCacheStats {
available: boolean
message?: string
@@ -303,6 +316,7 @@ export interface ModelMappingCacheStats {
global_model_resolve: number
}
mappings?: ModelMappingItem[]
+ provider_model_mappings?: ProviderModelMapping[] | null
unmapped?: UnmappedEntry[] | null
}
@@ -337,5 +351,13 @@ export const modelMappingCacheApi = {
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
return response.data
+ },
+ /**
+  * 清除指定 Provider 和 GlobalModel 的映射缓存
+  */
+ async clearProviderModel(providerId: string, globalModelId: string): Promise<ClearModelMappingCacheResponse> {
+   const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/provider/${providerId}/${globalModelId}`)
+   return response.data
}
}

View File

@@ -299,6 +299,26 @@ async function clearModelMappingByName(modelName: string) {
}
}
async function clearProviderModelMapping(providerId: string, globalModelId: string, displayName?: string) {
const confirmed = await showConfirm({
title: '确认清除',
message: `确定要清除 ${displayName || 'Provider 模型映射'} 的缓存吗?`,
confirmText: '确认清除',
variant: 'destructive'
})
if (!confirmed) return
try {
await modelMappingCacheApi.clearProviderModel(providerId, globalModelId)
showSuccess('已清除 Provider 模型映射缓存')
await fetchModelMappingStats()
} catch (error) {
showError('清除缓存失败')
log.error('清除 Provider 模型映射缓存失败', error)
}
}
function formatTTL(ttl: number | null): string {
if (ttl === null || ttl < 0) return '-'
if (ttl < 60) return `${ttl}s`
@@ -872,9 +892,125 @@ onBeforeUnmount(() => {
</div>
</div>
<!-- Provider 模型映射缓存 -->
<div
v-if="modelMappingStats?.available && modelMappingStats.provider_model_mappings && modelMappingStats.provider_model_mappings.length > 0"
class="border-t border-border/40"
>
<div class="px-6 py-3 text-xs text-muted-foreground border-b border-border/30 bg-muted/20">
Provider 模型映射缓存
</div>
<!-- 桌面端表格 -->
<Table class="hidden md:table">
<TableHeader>
<TableRow>
<TableHead class="w-[15%]">
提供商
</TableHead>
<TableHead class="w-[25%]">
请求名称
</TableHead>
<TableHead class="w-8 text-center" />
<TableHead class="w-[25%]">
映射模型
</TableHead>
<TableHead class="w-[10%] text-center">
剩余
</TableHead>
<TableHead class="w-[10%] text-center">
次数
</TableHead>
<TableHead class="w-[7%] text-right">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="index"
>
<TableRow
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`${index}-${aliasIndex}`"
>
<TableCell>
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
</TableCell>
<TableCell>
<span class="text-sm font-mono">{{ alias }}</span>
</TableCell>
<TableCell class="text-center">
<ArrowRight class="h-4 w-4 text-muted-foreground" />
</TableCell>
<TableCell>
<span class="text-sm font-mono font-medium">{{ mapping.provider_model_name }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-sm">{{ mapping.hit_count || 0 }}</span>
</TableCell>
<TableCell class="text-right">
<Button
size="icon"
variant="ghost"
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
</TableCell>
</TableRow>
</template>
</TableBody>
</Table>
<!-- 移动端卡片 -->
<div class="md:hidden divide-y divide-border/40">
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="`m-pm-${index}`"
>
<div
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`m-pm-${index}-${aliasIndex}`"
class="p-4 space-y-2"
>
<div class="flex items-center justify-between">
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
<div class="flex items-center gap-2">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
<span class="text-xs">{{ mapping.hit_count || 0 }}</span>
<Button
size="icon"
variant="ghost"
class="h-6 w-6 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3 w-3" />
</Button>
</div>
</div>
<div class="flex items-center gap-2 text-sm">
<span class="font-mono">{{ alias }}</span>
<ArrowRight class="h-3.5 w-3.5 shrink-0 text-muted-foreground/60" />
<span class="font-mono font-medium">{{ mapping.provider_model_name }}</span>
</div>
</div>
</template>
</div>
</div>
<!-- 无缓存状态 -->
<div
- v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0)"
+ v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0) && (!modelMappingStats.provider_model_mappings || modelMappingStats.provider_model_mappings.length === 0)"
class="px-6 py-8 text-center text-sm text-muted-foreground"
>
暂无模型解析缓存

View File

@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
+ from src.api.base.context import ApiRequestContext
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync
@@ -87,19 +88,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
# 2. 尝试作为 Username 查询
user = db.query(User).filter(User.username == identifier).first()
if user:
- logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
+ logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id

# 3. 尝试作为 Email 查询
user = db.query(User).filter(User.email == identifier).first()
if user:
- logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
+ logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id

# 4. 尝试作为 API Key ID 查询
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
if api_key:
- logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
+ logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
return api_key.user_id

# 无法识别
@@ -111,7 +112,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
async def get_cache_stats(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
获取缓存亲和性统计信息
@@ -131,7 +132,7 @@ async def get_user_affinity(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
查询指定用户的所有缓存亲和性
@@ -157,7 +158,7 @@ async def list_affinities(
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
获取所有缓存亲和性列表,可选按关键词过滤
@@ -173,7 +174,7 @@ async def clear_user_cache(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
Clear cache affinity for a specific user
@@ -188,7 +189,7 @@ async def clear_user_cache(
async def clear_all_cache(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
Clear all cache affinities
@@ -203,7 +204,7 @@ async def clear_provider_cache(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
Clear cache affinities for a specific provider
@@ -218,7 +219,7 @@ async def clear_provider_cache(
async def get_cache_config(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
获取缓存相关配置
@@ -234,7 +235,7 @@ async def get_cache_config(
async def get_cache_metrics(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
"""
@@ -246,7 +247,7 @@ async def get_cache_metrics(
class AdminCacheStatsAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -266,7 +267,7 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
class AdminCacheMetricsAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -391,7 +392,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
class AdminGetUserAffinityAdapter(AdminApiAdapter):
user_identifier: str
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
user_id = resolve_user_identifier(db, self.user_identifier)
@@ -472,7 +473,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
limit: int
offset: int
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
redis_client = get_redis_client_sync()
if not redis_client:
@@ -682,7 +683,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
class AdminClearUserCacheAdapter(AdminApiAdapter):
user_identifier: str
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
@@ -786,7 +787,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
class AdminClearAllCacheAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -806,7 +807,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
class AdminClearProviderCacheAdapter(AdminApiAdapter):
provider_id: str
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -829,7 +830,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
@@ -878,7 +879,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
async def get_model_mapping_cache_stats(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
获取模型映射缓存统计信息
@@ -895,7 +896,7 @@ async def get_model_mapping_cache_stats(
async def clear_all_model_mapping_cache(
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
清除所有模型映射缓存
@@ -910,7 +911,7 @@ async def clear_model_mapping_cache_by_name(
model_name: str,
request: Request,
db: Session = Depends(get_db),
- ):
+ ) -> Any:
"""
清除指定模型名称的映射缓存
@@ -921,8 +922,28 @@ async def clear_model_mapping_cache_by_name(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
async def clear_provider_model_mapping_cache(
provider_id: str,
global_model_id: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
清除指定 Provider 和 GlobalModel 的模型映射缓存
参数:
- provider_id: Provider ID
- global_model_id: GlobalModel ID
"""
adapter = AdminClearProviderModelMappingCacheAdapter(
provider_id=provider_id, global_model_id=global_model_id
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
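
To exercise the new route by hand, a minimal sketch along these lines should work; only the URL path comes from the diff, while the host, port, and admin bearer token are assumptions to adapt to your deployment:

import asyncio

import httpx  # assumed to be available alongside the project


async def clear_provider_model_cache(provider_id: str, global_model_id: str) -> dict:
    # Path mirrors the route registered above and the frontend client; host/port/token are placeholders.
    url = (
        "http://localhost:8000/api/admin/monitoring/cache/"
        f"model-mapping/provider/{provider_id}/{global_model_id}"
    )
    async with httpx.AsyncClient() as client:
        resp = await client.delete(url, headers={"Authorization": "Bearer <admin-token>"})
        resp.raise_for_status()
        return resp.json()


if __name__ == "__main__":
    result = asyncio.run(clear_provider_model_cache("<provider-id>", "<global-model-id>"))
    print(result.get("deleted_keys"))
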
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
import json
from src.clients.redis_client import get_redis_client
@@ -955,7 +976,9 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if key_str.startswith("model:id:"): if key_str.startswith("model:id:"):
model_id_keys.append(key_str) model_id_keys.append(key_str)
elif key_str.startswith("model:provider_global:"): elif key_str.startswith("model:provider_global:"):
provider_global_keys.append(key_str) # 过滤掉 hits 统计键,只保留实际的缓存键
if not key_str.startswith("model:provider_global:hits:"):
provider_global_keys.append(key_str)
async for key in redis.scan_iter(match="global_model:*", count=100): async for key in redis.scan_iter(match="global_model:*", count=100):
key_str = key.decode() if isinstance(key, bytes) else key key_str = key.decode() if isinstance(key, bytes) else key
@@ -1067,6 +1090,85 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
# 按 mapping_name 排序
mappings.sort(key=lambda x: x["mapping_name"])
# 3. 解析 provider_global 缓存(Provider 级别的模型解析缓存)
provider_model_mappings = []
# 预加载 Provider 和 GlobalModel 数据
provider_map = {str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()}
global_model_map = {str(gm.id): gm for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()}
for key in provider_global_keys[:100]: # 最多处理 100 个
# key 格式: model:provider_global:{provider_id}:{global_model_id}
try:
parts = key.replace("model:provider_global:", "").split(":")
if len(parts) != 2:
continue
provider_id, global_model_id = parts
cached_value = await redis.get(key)
ttl = await redis.ttl(key)
# 获取命中次数
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
hit_count_raw = await redis.get(hit_count_key)
hit_count = int(hit_count_raw) if hit_count_raw else 0
if cached_value:
cached_str = (
cached_value.decode()
if isinstance(cached_value, bytes)
else cached_value
)
try:
cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name")
provider_model_aliases = cached_data.get("provider_model_aliases", [])
# 获取 Provider 和 GlobalModel 信息
provider = provider_map.get(provider_id)
global_model = global_model_map.get(global_model_id)
if provider and global_model:
# 提取别名名称
alias_names = []
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict) and alias_entry.get("name"):
alias_names.append(alias_entry["name"])
# provider_model_name 为空时跳过
if not provider_model_name:
continue
# 只显示有实际映射的条目:
# 1. 全局模型名 != Provider 模型名(模型名称映射)
# 2. 或者有别名配置
has_name_mapping = global_model.name != provider_model_name
has_aliases = len(alias_names) > 0
if has_name_mapping or has_aliases:
# 构建用于展示的别名列表
# 如果只有名称映射没有别名,则用 global_model_name 作为"请求名称"
display_aliases = alias_names if alias_names else [global_model.name]
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,
"provider_model_name": provider_model_name,
"aliases": display_aliases,
"ttl": ttl if ttl > 0 else None,
"hit_count": hit_count,
})
except json.JSONDecodeError:
pass
except Exception as e:
logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")
# 按 provider_name + global_model_name 排序
provider_model_mappings.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
response_data = {
"available": True,
"ttl_seconds": CacheTTL.MODEL,
@@ -1079,6 +1181,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
"global_model_resolve": len(global_model_resolve_keys), "global_model_resolve": len(global_model_resolve_keys),
}, },
"mappings": mappings, "mappings": mappings,
"provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
"unmapped": unmapped_entries if unmapped_entries else None, "unmapped": unmapped_entries if unmapped_entries else None,
} }
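
For reference, one entry of provider_model_mappings in the response above has the shape below; every value is an illustrative placeholder, only the keys come from the adapter code:

example_entry = {
    "provider_id": "prov-...",                  # placeholder ID
    "provider_name": "Example Provider",        # Provider.display_name or Provider.name
    "global_model_id": "gm-...",                # placeholder ID
    "global_model_name": "example-model",       # GlobalModel.name
    "global_model_display_name": None,
    "provider_model_name": "example-model-v2",  # name actually sent to the provider
    "aliases": ["example-model"],               # request names shown in the UI table
    "ttl": 600,                                 # remaining seconds, or None
    "hit_count": 3,
}
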
@@ -1094,7 +1197,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1136,7 +1239,7 @@ class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
model_name: str
- async def handle(self, context): # type: ignore[override]
+ async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1176,3 +1279,55 @@ class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
except Exception as exc:
logger.exception(f"清除模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
provider_id: str
global_model_id: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
redis = await get_redis_client(require_redis=False)
if not redis:
raise HTTPException(status_code=503, detail="Redis 未启用")
deleted_keys = []
# 清除 provider_global 缓存
provider_global_key = f"model:provider_global:{self.provider_id}:{self.global_model_id}"
if await redis.exists(provider_global_key):
await redis.delete(provider_global_key)
deleted_keys.append(provider_global_key)
# 清除对应的 hit_count 缓存
hit_count_key = f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}"
if await redis.exists(hit_count_key):
await redis.delete(hit_count_key)
deleted_keys.append(hit_count_key)
logger.info(
f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
)
context.add_audit_metadata(
action="provider_model_mapping_cache_clear",
provider_id=self.provider_id,
global_model_id=self.global_model_id,
deleted_keys=deleted_keys,
)
return {
"status": "ok",
"message": "已清除 Provider 模型映射缓存",
"provider_id": self.provider_id,
"global_model_id": self.global_model_id,
"deleted_keys": deleted_keys,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
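
The adapter above only ever touches two keys per provider/global-model pair. A rough equivalent against a bare Redis client (redis.asyncio assumed; the key layout is taken from the diff):

from redis.asyncio import Redis  # assumed client library


async def clear_provider_model_keys(redis: Redis, provider_id: str, global_model_id: str) -> list:
    # Same key layout as ModelCacheService and the admin adapter above.
    cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
    hits_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
    deleted = []
    for key in (cache_key, hits_key):
        if await redis.delete(key):  # DEL returns the number of keys removed
            deleted.append(key)
    return deleted
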

View File

@@ -395,3 +395,24 @@ class BaseMessageHandler:
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
Args:
message: 错误消息前缀
error: 异常对象
"""
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# 业务异常:简洁日志,不打印堆栈
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:
# 未知异常:完整堆栈
logger.exception(f"{message}: {error}")

View File

@@ -382,7 +382,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
except Exception as e:
- logger.exception(f"流式请求失败: {e}")
+ self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise

View File

@@ -413,20 +413,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
)
except Exception as e:
- # 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈
- from src.core.exceptions import (
-     ProviderException,
-     QuotaExceededException,
-     RateLimitException,
-     ModelNotSupportedException,
- )
- if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
-     # 业务异常:简洁日志
-     logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
- else:
-     # 未知异常:完整堆栈
-     logger.exception(f"流式请求失败: {e}")
+ self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise

View File

@@ -120,6 +120,33 @@ class CacheService:
logger.warning(f"缓存检查失败: {key} - {e}") logger.warning(f"缓存检查失败: {key} - {e}")
return False return False
@staticmethod
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
"""
递增缓存值
Args:
key: 缓存键
ttl_seconds: 可选,如果提供则刷新 TTL
Returns:
递增后的值,如果失败返回 0
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return 0
result = await redis.incr(key)
# 如果提供了 TTL,刷新过期时间
if ttl_seconds is not None:
await redis.expire(key, ttl_seconds)
return result
except Exception as e:
logger.warning(f"缓存递增失败: {key} - {e}")
return 0
# 缓存键前缀
class CacheKeys:
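
The new incr() is essentially INCR plus an optional EXPIRE, which the hit counter in the next file relies on so the counter's lifetime stays tied to its cache entry. A standalone sketch of the same pattern against redis.asyncio (key name and TTL are illustrative, not taken from a real deployment):

import asyncio

from redis.asyncio import Redis  # assumed client library


async def incr_with_ttl(redis: Redis, key: str, ttl_seconds: int) -> int:
    # Same contract as CacheService.incr: bump the counter, then refresh its expiry
    # so the counter expires together with the cached entry it belongs to.
    value = await redis.incr(key)
    await redis.expire(key, ttl_seconds)
    return value


async def main() -> None:
    redis = Redis()  # localhost:6379 by default
    hits = await incr_with_ttl(redis, "model:provider_global:hits:<provider-id>:<global-model-id>", 600)
    print(f"hit_count={hits}")
    await redis.close()


if __name__ == "__main__":
    asyncio.run(main())
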

View File

@@ -2,11 +2,9 @@
Model 映射缓存服务 - 减少模型查询
"""
- import json
import time
- from typing import Optional
+ from typing import List, Optional
- from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
@@ -106,6 +104,7 @@ class ModelCacheService:
Model 对象或 None
"""
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
+ hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
@@ -113,6 +112,8 @@ class ModelCacheService:
logger.debug(
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
+ # 递增命中计数,同时刷新 TTL
+ await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
@@ -130,6 +131,8 @@ class ModelCacheService:
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
+ # 重置命中计数(新缓存从1开始)
+ await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
@@ -189,9 +192,10 @@ class ModelCacheService:
# 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}")
- # 清除 provider_global 缓存(如果提供了必要参数)
+ # 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
if provider_id and global_model_id:
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
+ await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
logger.debug(
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
)
@@ -230,16 +234,20 @@ class ModelCacheService:
db: Session, model_name: str
) -> Optional[GlobalModel]:
"""
- 通过名称或映射解析 GlobalModel(带缓存,支持映射匹配)
+ 通过名称解析 GlobalModel(带缓存)
查找顺序:
1. 检查缓存
- 2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases)
+ 2. 通过 provider_model_name 匹配(查询 Model 表)
3. 直接匹配 GlobalModel.name(兜底)
+ 注意:此方法不使用 provider_model_aliases 进行全局解析。
+ provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
+ 由 resolve_provider_model() 处理。
Args:
db: 数据库会话
- model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
+ model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name)
Returns:
GlobalModel 对象或 None
@@ -273,116 +281,53 @@ class ModelCacheService:
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}") logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data) return ModelCacheService._dict_to_global_model(cached_data)
# 2. 优先通过 provider_model_name 和映射名称匹配Provider 配置优先级最高 # 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases
from sqlalchemy import or_ # 重要provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 别名配置的影响
# 例如Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
from src.models.database import Provider from src.models.database import Provider
# 构建精确的映射匹配条件 models_with_global = (
# 注意provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符 db.query(Model, GlobalModel)
# 对于 SQLite会在 Python 层面进行过滤 .join(Provider, Model.provider_id == Provider.id)
try: .join(GlobalModel, Model.global_model_id == GlobalModel.id)
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效) .filter(
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入 Provider.is_active == True,
jsonb_pattern = json.dumps([{"name": normalized_name}]) Model.is_active == True,
models_with_global = ( GlobalModel.is_active == True,
db.query(Model, GlobalModel) Model.provider_model_name == normalized_name,
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
or_(
Model.provider_model_name == normalized_name,
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
Model.provider_model_aliases.op("@>")(jsonb_pattern),
),
)
.all()
)
except (OperationalError, ProgrammingError) as e:
# JSONB 操作符不支持(如 SQLite回退到加载匹配 provider_model_name 的 Model
# 并在 Python 层过滤 aliases
logger.debug(
f"JSONB 查询失败,回退到 Python 过滤: {e}",
)
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
)
.all()
) )
.all()
)
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)} # 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
# 使用字典去重,同一个 Model 只保留优先级最高的匹配 matched_global_models: List[GlobalModel] = []
matched_models_dict = {} seen_global_model_ids: set[str] = set()
# 遍历查询结果进行匹配
for model, gm in models_with_global: for model, gm in models_with_global:
key = (model.id, gm.id) if gm.id not in seen_global_model_ids:
seen_global_model_ids.add(gm.id)
# 检查 provider_model_aliases 是否匹配(优先级更高) matched_global_models.append(gm)
if model.provider_model_aliases: logger.debug(
for alias_entry in model.provider_model_aliases: f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
if isinstance(alias_entry, dict): f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
alias_name = alias_entry.get("name", "").strip()
if alias_name == normalized_name:
# alias 优先级为 0最高覆盖任何已存在的匹配
matched_models_dict[key] = (gm, "alias", 0)
logger.debug(
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
break
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
if model.provider_model_name == normalized_name:
# provider_model_name 优先级为 1兜底只在没有 alias 匹配时使用
if key not in matched_models_dict:
matched_models_dict[key] = (gm, "provider_model_name", 1)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
# 如果通过 provider_model_name/alias 找到了,直接返回
if matched_models_dict:
# 转换为列表并排序:按 priorityalias=0 优先)、然后按 GlobalModel.name
matched_global_models = [
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
]
matched_global_models.sort(
key=lambda item: (
0 if item[1] == "alias" else 1, # alias 优先
item[0].name # 同优先级按名称排序(确定性)
) )
)
# 记录解析方式 # 如果通过 provider_model_name 找到了,返回
resolution_method = matched_global_models[0][1] if matched_global_models:
resolution_method = "provider_model_name"
if len(matched_global_models) > 1: if len(matched_global_models) > 1:
# 检测到冲突 # 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name
unique_models = {gm.id: gm for gm, _ in matched_global_models} model_names = [gm.name for gm in matched_global_models if gm.name]
if len(unique_models) > 1: logger.warning(
model_names = [gm.name for gm in unique_models.values()] f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
logger.warning( f"{', '.join(model_names)},使用第一个匹配结果"
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: " )
f"{', '.join(model_names)},使用第一个匹配结果" # 记录冲突指标
) model_mapping_conflict_total.inc()
# 记录冲突指标
model_mapping_conflict_total.inc()
# 返回第一个匹配的 GlobalModel # 返回第一个匹配的 GlobalModel
result_global_model: GlobalModel = matched_global_models[0][0] result_global_model = matched_global_models[0]
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model) global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
await CacheService.set( await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
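
The net effect of this rewrite is a much simpler resolution order: exact provider_model_name match first, then GlobalModel.name as a fallback, with Provider-level aliases no longer consulted globally. A self-contained toy illustration of that order (the dataclasses below are stand-ins, not the project's ORM models):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyGlobalModel:
    id: str
    name: str


@dataclass
class ToyModel:
    provider_model_name: str
    global_model: ToyGlobalModel


def resolve(name: str, models: List[ToyModel], global_models: List[ToyGlobalModel]) -> Optional[ToyGlobalModel]:
    # 1. provider_model_name match (Provider-level aliases deliberately ignored here,
    #    matching the behaviour change described in the docstring)
    matches = [m.global_model for m in models if m.provider_model_name == name]
    unique = {gm.id: gm for gm in matches}
    if unique:
        if len(unique) > 1:
            print(f"conflict: '{name}' maps to {sorted(gm.name for gm in unique.values())}, using the first")
        return next(iter(unique.values()))
    # 2. fallback: direct GlobalModel.name match
    return next((gm for gm in global_models if gm.name == name), None)


if __name__ == "__main__":
    sonnet = ToyGlobalModel("gm-1", "example-global-model")
    models = [ToyModel(provider_model_name="example-provider-model", global_model=sonnet)]
    print(resolve("example-provider-model", models, [sonnet]).name)  # provider_model_name hit
    print(resolve("example-global-model", models, [sonnet]).name)    # GlobalModel.name fallback
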

View File

@@ -6,7 +6,7 @@
- 根据 API 格式或端点配置生成请求 URL
"""
- from typing import Any, Dict, Optional
+ from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlencode
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
from src.core.enums import APIFormat
from src.core.logger import logger
+ if TYPE_CHECKING:
+     from src.models.database import ProviderAPIKey, ProviderEndpoint
def build_provider_headers(
- endpoint,
- key,
+ endpoint: "ProviderEndpoint",
+ key: "ProviderAPIKey",
original_headers: Optional[Dict[str, str]] = None,
*,
extra_headers: Optional[Dict[str, str]] = None,
@@ -28,7 +31,8 @@ def build_provider_headers(
""" """
headers: Dict[str, str] = {} headers: Dict[str, str] = {}
decrypted_key = crypto_service.decrypt(key.api_key) # api_key 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
# 根据 API 格式自动选择认证头 # 根据 API 格式自动选择认证头
api_format = getattr(endpoint, "api_format", None) api_format = getattr(endpoint, "api_format", None)
@@ -68,8 +72,32 @@ def build_provider_headers(
return headers
def _normalize_base_url(base_url: str, path: str) -> str:
"""
规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
避免拼接出 /v1/v1/messages 这样的重复路径。
兼容用户填写的各种格式:
- https://api.example.com
- https://api.example.com/
- https://api.example.com/v1
- https://api.example.com/v1/
"""
base = base_url.rstrip("/")
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
# 例如base="/v1", path="/v1/messages" -> 去除 /v1
# 例如base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
if base.endswith(suffix) and path.startswith(suffix):
base = base[: -len(suffix)]
break
return base
def build_provider_url(
- endpoint,
+ endpoint: "ProviderEndpoint",
*,
query_params: Optional[Dict[str, Any]] = None,
path_params: Optional[Dict[str, Any]] = None,
@@ -88,8 +116,6 @@ def build_provider_url(
path_params: 路径模板参数 (如 {model})
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
"""
- base = endpoint.base_url.rstrip("/")
# 准备路径参数,添加 Gemini API 所需的 action 参数
effective_path_params = dict(path_params) if path_params else {}
@@ -123,6 +149,9 @@ def build_provider_url(
if not path.startswith("/"): if not path.startswith("/"):
path = f"/{path}" path = f"/{path}"
# 先确定 path再根据 path 规范化 base_url
# base_url 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
url = f"{base}{path}" url = f"{base}{path}"
# 添加查询参数 # 添加查询参数
@@ -134,7 +163,7 @@ def build_provider_url(
return url
- def _resolve_default_path(api_format) -> str:
+ def _resolve_default_path(api_format: Optional[str]) -> str:
"""
根据 API 格式返回默认路径
"""