mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 02:02:27 +08:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3a69a6160 | ||
|
|
adcdb73d29 | ||
|
|
cf67160821 | ||
|
|
718f56ba75 | ||
|
|
d87de10f62 |
@@ -12,8 +12,6 @@ services:
|
|||||||
TZ: Asia/Shanghai
|
TZ: Asia/Shanghai
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
ports:
|
|
||||||
- "${DB_PORT:-5432}:5432"
|
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
interval: 5s
|
interval: 5s
|
||||||
@@ -27,8 +25,6 @@ services:
|
|||||||
command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
|
command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
ports:
|
|
||||||
- "${REDIS_PORT:-6379}:6379"
|
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||||
interval: 5s
|
interval: 5s
|
||||||
|
|||||||
@@ -290,6 +290,19 @@ export interface UnmappedEntry {
|
|||||||
ttl: number | null
|
ttl: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Provider 模型映射缓存(Redis 缓存)
|
||||||
|
export interface ProviderModelMapping {
|
||||||
|
provider_id: string
|
||||||
|
provider_name: string
|
||||||
|
global_model_id: string
|
||||||
|
global_model_name: string
|
||||||
|
global_model_display_name: string | null
|
||||||
|
provider_model_name: string
|
||||||
|
aliases: string[] | null
|
||||||
|
ttl: number | null
|
||||||
|
hit_count: number
|
||||||
|
}
|
||||||
|
|
||||||
export interface ModelMappingCacheStats {
|
export interface ModelMappingCacheStats {
|
||||||
available: boolean
|
available: boolean
|
||||||
message?: string
|
message?: string
|
||||||
@@ -303,6 +316,7 @@ export interface ModelMappingCacheStats {
|
|||||||
global_model_resolve: number
|
global_model_resolve: number
|
||||||
}
|
}
|
||||||
mappings?: ModelMappingItem[]
|
mappings?: ModelMappingItem[]
|
||||||
|
provider_model_mappings?: ProviderModelMapping[] | null
|
||||||
unmapped?: UnmappedEntry[] | null
|
unmapped?: UnmappedEntry[] | null
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,5 +351,13 @@ export const modelMappingCacheApi = {
|
|||||||
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
|
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
|
||||||
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
|
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清除指定 Provider 和 GlobalModel 的映射缓存
|
||||||
|
*/
|
||||||
|
async clearProviderModel(providerId: string, globalModelId: string): Promise<ClearModelMappingCacheResponse> {
|
||||||
|
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/provider/${providerId}/${globalModelId}`)
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -299,6 +299,26 @@ async function clearModelMappingByName(modelName: string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function clearProviderModelMapping(providerId: string, globalModelId: string, displayName?: string) {
|
||||||
|
const confirmed = await showConfirm({
|
||||||
|
title: '确认清除',
|
||||||
|
message: `确定要清除 ${displayName || 'Provider 模型映射'} 的缓存吗?`,
|
||||||
|
confirmText: '确认清除',
|
||||||
|
variant: 'destructive'
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!confirmed) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
await modelMappingCacheApi.clearProviderModel(providerId, globalModelId)
|
||||||
|
showSuccess('已清除 Provider 模型映射缓存')
|
||||||
|
await fetchModelMappingStats()
|
||||||
|
} catch (error) {
|
||||||
|
showError('清除缓存失败')
|
||||||
|
log.error('清除 Provider 模型映射缓存失败', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function formatTTL(ttl: number | null): string {
|
function formatTTL(ttl: number | null): string {
|
||||||
if (ttl === null || ttl < 0) return '-'
|
if (ttl === null || ttl < 0) return '-'
|
||||||
if (ttl < 60) return `${ttl}s`
|
if (ttl < 60) return `${ttl}s`
|
||||||
@@ -872,9 +892,125 @@ onBeforeUnmount(() => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Provider 模型映射缓存 -->
|
||||||
|
<div
|
||||||
|
v-if="modelMappingStats?.available && modelMappingStats.provider_model_mappings && modelMappingStats.provider_model_mappings.length > 0"
|
||||||
|
class="border-t border-border/40"
|
||||||
|
>
|
||||||
|
<div class="px-6 py-3 text-xs text-muted-foreground border-b border-border/30 bg-muted/20">
|
||||||
|
Provider 模型映射缓存
|
||||||
|
</div>
|
||||||
|
<!-- 桌面端表格 -->
|
||||||
|
<Table class="hidden md:table">
|
||||||
|
<TableHeader>
|
||||||
|
<TableRow>
|
||||||
|
<TableHead class="w-[15%]">
|
||||||
|
提供商
|
||||||
|
</TableHead>
|
||||||
|
<TableHead class="w-[25%]">
|
||||||
|
请求名称
|
||||||
|
</TableHead>
|
||||||
|
<TableHead class="w-8 text-center" />
|
||||||
|
<TableHead class="w-[25%]">
|
||||||
|
映射模型
|
||||||
|
</TableHead>
|
||||||
|
<TableHead class="w-[10%] text-center">
|
||||||
|
剩余
|
||||||
|
</TableHead>
|
||||||
|
<TableHead class="w-[10%] text-center">
|
||||||
|
次数
|
||||||
|
</TableHead>
|
||||||
|
<TableHead class="w-[7%] text-right">
|
||||||
|
操作
|
||||||
|
</TableHead>
|
||||||
|
</TableRow>
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
<template
|
||||||
|
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
|
||||||
|
:key="index"
|
||||||
|
>
|
||||||
|
<TableRow
|
||||||
|
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
|
||||||
|
:key="`${index}-${aliasIndex}`"
|
||||||
|
>
|
||||||
|
<TableCell>
|
||||||
|
<Badge variant="outline" class="text-xs">
|
||||||
|
{{ mapping.provider_name }}
|
||||||
|
</Badge>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<span class="text-sm font-mono">{{ alias }}</span>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell class="text-center">
|
||||||
|
<ArrowRight class="h-4 w-4 text-muted-foreground" />
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<span class="text-sm font-mono font-medium">{{ mapping.provider_model_name }}</span>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell class="text-center">
|
||||||
|
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell class="text-center">
|
||||||
|
<span class="text-sm">{{ mapping.hit_count || 0 }}</span>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell class="text-right">
|
||||||
|
<Button
|
||||||
|
size="icon"
|
||||||
|
variant="ghost"
|
||||||
|
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
|
||||||
|
title="清除缓存"
|
||||||
|
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
|
||||||
|
>
|
||||||
|
<Trash2 class="h-3.5 w-3.5" />
|
||||||
|
</Button>
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
</template>
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
<!-- 移动端卡片 -->
|
||||||
|
<div class="md:hidden divide-y divide-border/40">
|
||||||
|
<template
|
||||||
|
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
|
||||||
|
:key="`m-pm-${index}`"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
|
||||||
|
:key="`m-pm-${index}-${aliasIndex}`"
|
||||||
|
class="p-4 space-y-2"
|
||||||
|
>
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<Badge variant="outline" class="text-xs">
|
||||||
|
{{ mapping.provider_name }}
|
||||||
|
</Badge>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||||
|
<span class="text-xs">{{ mapping.hit_count || 0 }}次</span>
|
||||||
|
<Button
|
||||||
|
size="icon"
|
||||||
|
variant="ghost"
|
||||||
|
class="h-6 w-6 text-muted-foreground/70 hover:text-destructive"
|
||||||
|
title="清除缓存"
|
||||||
|
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
|
||||||
|
>
|
||||||
|
<Trash2 class="h-3 w-3" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center gap-2 text-sm">
|
||||||
|
<span class="font-mono">{{ alias }}</span>
|
||||||
|
<ArrowRight class="h-3.5 w-3.5 shrink-0 text-muted-foreground/60" />
|
||||||
|
<span class="font-mono font-medium">{{ mapping.provider_model_name }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 无缓存状态 -->
|
<!-- 无缓存状态 -->
|
||||||
<div
|
<div
|
||||||
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0)"
|
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0) && (!modelMappingStats.provider_model_mappings || modelMappingStats.provider_model_mappings.length === 0)"
|
||||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||||
>
|
>
|
||||||
暂无模型解析缓存
|
暂无模型解析缓存
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
|
from src.api.base.context import ApiRequestContext
|
||||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
|
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
|
||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
from src.clients.redis_client import get_redis_client_sync
|
from src.clients.redis_client import get_redis_client_sync
|
||||||
@@ -87,19 +88,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
|
|||||||
# 2. 尝试作为 Username 查询
|
# 2. 尝试作为 Username 查询
|
||||||
user = db.query(User).filter(User.username == identifier).first()
|
user = db.query(User).filter(User.username == identifier).first()
|
||||||
if user:
|
if user:
|
||||||
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
|
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
|
||||||
return user.id
|
return user.id
|
||||||
|
|
||||||
# 3. 尝试作为 Email 查询
|
# 3. 尝试作为 Email 查询
|
||||||
user = db.query(User).filter(User.email == identifier).first()
|
user = db.query(User).filter(User.email == identifier).first()
|
||||||
if user:
|
if user:
|
||||||
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
|
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
|
||||||
return user.id
|
return user.id
|
||||||
|
|
||||||
# 4. 尝试作为 API Key ID 查询
|
# 4. 尝试作为 API Key ID 查询
|
||||||
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
|
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
|
||||||
if api_key:
|
if api_key:
|
||||||
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
|
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
|
||||||
return api_key.user_id
|
return api_key.user_id
|
||||||
|
|
||||||
# 无法识别
|
# 无法识别
|
||||||
@@ -111,7 +112,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
|
|||||||
async def get_cache_stats(
|
async def get_cache_stats(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
获取缓存亲和性统计信息
|
获取缓存亲和性统计信息
|
||||||
|
|
||||||
@@ -131,7 +132,7 @@ async def get_user_affinity(
|
|||||||
user_identifier: str,
|
user_identifier: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
查询指定用户的所有缓存亲和性
|
查询指定用户的所有缓存亲和性
|
||||||
|
|
||||||
@@ -157,7 +158,7 @@ async def list_affinities(
|
|||||||
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
|
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
|
||||||
offset: int = Query(0, ge=0, description="偏移量"),
|
offset: int = Query(0, ge=0, description="偏移量"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
获取所有缓存亲和性列表,可选按关键词过滤
|
获取所有缓存亲和性列表,可选按关键词过滤
|
||||||
|
|
||||||
@@ -173,7 +174,7 @@ async def clear_user_cache(
|
|||||||
user_identifier: str,
|
user_identifier: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Clear cache affinity for a specific user
|
Clear cache affinity for a specific user
|
||||||
|
|
||||||
@@ -188,7 +189,7 @@ async def clear_user_cache(
|
|||||||
async def clear_all_cache(
|
async def clear_all_cache(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Clear all cache affinities
|
Clear all cache affinities
|
||||||
|
|
||||||
@@ -203,7 +204,7 @@ async def clear_provider_cache(
|
|||||||
provider_id: str,
|
provider_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Clear cache affinities for a specific provider
|
Clear cache affinities for a specific provider
|
||||||
|
|
||||||
@@ -218,7 +219,7 @@ async def clear_provider_cache(
|
|||||||
async def get_cache_config(
|
async def get_cache_config(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
获取缓存相关配置
|
获取缓存相关配置
|
||||||
|
|
||||||
@@ -234,7 +235,7 @@ async def get_cache_config(
|
|||||||
async def get_cache_metrics(
|
async def get_cache_metrics(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
|
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
|
||||||
"""
|
"""
|
||||||
@@ -246,7 +247,7 @@ async def get_cache_metrics(
|
|||||||
|
|
||||||
|
|
||||||
class AdminCacheStatsAdapter(AdminApiAdapter):
|
class AdminCacheStatsAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||||
@@ -266,7 +267,7 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
|
|
||||||
class AdminCacheMetricsAdapter(AdminApiAdapter):
|
class AdminCacheMetricsAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||||
@@ -391,7 +392,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
|
|||||||
class AdminGetUserAffinityAdapter(AdminApiAdapter):
|
class AdminGetUserAffinityAdapter(AdminApiAdapter):
|
||||||
user_identifier: str
|
user_identifier: str
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
try:
|
try:
|
||||||
user_id = resolve_user_identifier(db, self.user_identifier)
|
user_id = resolve_user_identifier(db, self.user_identifier)
|
||||||
@@ -472,7 +473,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
|||||||
limit: int
|
limit: int
|
||||||
offset: int
|
offset: int
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
if not redis_client:
|
if not redis_client:
|
||||||
@@ -682,7 +683,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
|||||||
class AdminClearUserCacheAdapter(AdminApiAdapter):
|
class AdminClearUserCacheAdapter(AdminApiAdapter):
|
||||||
user_identifier: str
|
user_identifier: str
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
@@ -786,7 +787,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
|
|
||||||
class AdminClearAllCacheAdapter(AdminApiAdapter):
|
class AdminClearAllCacheAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
affinity_mgr = await get_affinity_manager(redis_client)
|
affinity_mgr = await get_affinity_manager(redis_client)
|
||||||
@@ -806,7 +807,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
|
|||||||
class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
affinity_mgr = await get_affinity_manager(redis_client)
|
affinity_mgr = await get_affinity_manager(redis_client)
|
||||||
@@ -829,7 +830,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
|
|
||||||
class AdminCacheConfigAdapter(AdminApiAdapter):
|
class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
||||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||||
@@ -878,7 +879,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
|||||||
async def get_model_mapping_cache_stats(
|
async def get_model_mapping_cache_stats(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
获取模型映射缓存统计信息
|
获取模型映射缓存统计信息
|
||||||
|
|
||||||
@@ -895,7 +896,7 @@ async def get_model_mapping_cache_stats(
|
|||||||
async def clear_all_model_mapping_cache(
|
async def clear_all_model_mapping_cache(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
清除所有模型映射缓存
|
清除所有模型映射缓存
|
||||||
|
|
||||||
@@ -910,7 +911,7 @@ async def clear_model_mapping_cache_by_name(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
清除指定模型名称的映射缓存
|
清除指定模型名称的映射缓存
|
||||||
|
|
||||||
@@ -921,8 +922,28 @@ async def clear_model_mapping_cache_by_name(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
|
||||||
|
async def clear_provider_model_mapping_cache(
|
||||||
|
provider_id: str,
|
||||||
|
global_model_id: str,
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
清除指定 Provider 和 GlobalModel 的模型映射缓存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- provider_id: Provider ID
|
||||||
|
- global_model_id: GlobalModel ID
|
||||||
|
"""
|
||||||
|
adapter = AdminClearProviderModelMappingCacheAdapter(
|
||||||
|
provider_id=provider_id, global_model_id=global_model_id
|
||||||
|
)
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
@@ -955,7 +976,9 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
if key_str.startswith("model:id:"):
|
if key_str.startswith("model:id:"):
|
||||||
model_id_keys.append(key_str)
|
model_id_keys.append(key_str)
|
||||||
elif key_str.startswith("model:provider_global:"):
|
elif key_str.startswith("model:provider_global:"):
|
||||||
provider_global_keys.append(key_str)
|
# 过滤掉 hits 统计键,只保留实际的缓存键
|
||||||
|
if not key_str.startswith("model:provider_global:hits:"):
|
||||||
|
provider_global_keys.append(key_str)
|
||||||
|
|
||||||
async for key in redis.scan_iter(match="global_model:*", count=100):
|
async for key in redis.scan_iter(match="global_model:*", count=100):
|
||||||
key_str = key.decode() if isinstance(key, bytes) else key
|
key_str = key.decode() if isinstance(key, bytes) else key
|
||||||
@@ -1067,6 +1090,85 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
# 按 mapping_name 排序
|
# 按 mapping_name 排序
|
||||||
mappings.sort(key=lambda x: x["mapping_name"])
|
mappings.sort(key=lambda x: x["mapping_name"])
|
||||||
|
|
||||||
|
# 3. 解析 provider_global 缓存(Provider 级别的模型解析缓存)
|
||||||
|
provider_model_mappings = []
|
||||||
|
# 预加载 Provider 和 GlobalModel 数据
|
||||||
|
provider_map = {str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()}
|
||||||
|
global_model_map = {str(gm.id): gm for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()}
|
||||||
|
|
||||||
|
for key in provider_global_keys[:100]: # 最多处理 100 个
|
||||||
|
# key 格式: model:provider_global:{provider_id}:{global_model_id}
|
||||||
|
try:
|
||||||
|
parts = key.replace("model:provider_global:", "").split(":")
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
provider_id, global_model_id = parts
|
||||||
|
|
||||||
|
cached_value = await redis.get(key)
|
||||||
|
ttl = await redis.ttl(key)
|
||||||
|
|
||||||
|
# 获取命中次数
|
||||||
|
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
|
||||||
|
hit_count_raw = await redis.get(hit_count_key)
|
||||||
|
hit_count = int(hit_count_raw) if hit_count_raw else 0
|
||||||
|
|
||||||
|
if cached_value:
|
||||||
|
cached_str = (
|
||||||
|
cached_value.decode()
|
||||||
|
if isinstance(cached_value, bytes)
|
||||||
|
else cached_value
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
cached_data = json.loads(cached_str)
|
||||||
|
provider_model_name = cached_data.get("provider_model_name")
|
||||||
|
provider_model_aliases = cached_data.get("provider_model_aliases", [])
|
||||||
|
|
||||||
|
# 获取 Provider 和 GlobalModel 信息
|
||||||
|
provider = provider_map.get(provider_id)
|
||||||
|
global_model = global_model_map.get(global_model_id)
|
||||||
|
|
||||||
|
if provider and global_model:
|
||||||
|
# 提取别名名称
|
||||||
|
alias_names = []
|
||||||
|
if provider_model_aliases:
|
||||||
|
for alias_entry in provider_model_aliases:
|
||||||
|
if isinstance(alias_entry, dict) and alias_entry.get("name"):
|
||||||
|
alias_names.append(alias_entry["name"])
|
||||||
|
|
||||||
|
# provider_model_name 为空时跳过
|
||||||
|
if not provider_model_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 只显示有实际映射的条目:
|
||||||
|
# 1. 全局模型名 != Provider 模型名(模型名称映射)
|
||||||
|
# 2. 或者有别名配置
|
||||||
|
has_name_mapping = global_model.name != provider_model_name
|
||||||
|
has_aliases = len(alias_names) > 0
|
||||||
|
|
||||||
|
if has_name_mapping or has_aliases:
|
||||||
|
# 构建用于展示的别名列表
|
||||||
|
# 如果只有名称映射没有别名,则用 global_model_name 作为"请求名称"
|
||||||
|
display_aliases = alias_names if alias_names else [global_model.name]
|
||||||
|
|
||||||
|
provider_model_mappings.append({
|
||||||
|
"provider_id": provider_id,
|
||||||
|
"provider_name": provider.display_name or provider.name,
|
||||||
|
"global_model_id": global_model_id,
|
||||||
|
"global_model_name": global_model.name,
|
||||||
|
"global_model_display_name": global_model.display_name,
|
||||||
|
"provider_model_name": provider_model_name,
|
||||||
|
"aliases": display_aliases,
|
||||||
|
"ttl": ttl if ttl > 0 else None,
|
||||||
|
"hit_count": hit_count,
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")
|
||||||
|
|
||||||
|
# 按 provider_name + global_model_name 排序
|
||||||
|
provider_model_mappings.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"available": True,
|
"available": True,
|
||||||
"ttl_seconds": CacheTTL.MODEL,
|
"ttl_seconds": CacheTTL.MODEL,
|
||||||
@@ -1079,6 +1181,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
"global_model_resolve": len(global_model_resolve_keys),
|
"global_model_resolve": len(global_model_resolve_keys),
|
||||||
},
|
},
|
||||||
"mappings": mappings,
|
"mappings": mappings,
|
||||||
|
"provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
|
||||||
"unmapped": unmapped_entries if unmapped_entries else None,
|
"unmapped": unmapped_entries if unmapped_entries else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1094,7 +1197,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
|
|
||||||
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
|
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1136,7 +1239,7 @@ class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
|
|||||||
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
|
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
|
||||||
model_name: str
|
model_name: str
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1176,3 +1279,55 @@ class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(f"清除模型映射缓存失败: {exc}")
|
logger.exception(f"清除模型映射缓存失败: {exc}")
|
||||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
|
||||||
|
provider_id: str
|
||||||
|
global_model_id: str
|
||||||
|
|
||||||
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis = await get_redis_client(require_redis=False)
|
||||||
|
if not redis:
|
||||||
|
raise HTTPException(status_code=503, detail="Redis 未启用")
|
||||||
|
|
||||||
|
deleted_keys = []
|
||||||
|
|
||||||
|
# 清除 provider_global 缓存
|
||||||
|
provider_global_key = f"model:provider_global:{self.provider_id}:{self.global_model_id}"
|
||||||
|
if await redis.exists(provider_global_key):
|
||||||
|
await redis.delete(provider_global_key)
|
||||||
|
deleted_keys.append(provider_global_key)
|
||||||
|
|
||||||
|
# 清除对应的 hit_count 缓存
|
||||||
|
hit_count_key = f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}"
|
||||||
|
if await redis.exists(hit_count_key):
|
||||||
|
await redis.delete(hit_count_key)
|
||||||
|
deleted_keys.append(hit_count_key)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
|
||||||
|
f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
|
||||||
|
)
|
||||||
|
context.add_audit_metadata(
|
||||||
|
action="provider_model_mapping_cache_clear",
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
global_model_id=self.global_model_id,
|
||||||
|
deleted_keys=deleted_keys,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"message": "已清除 Provider 模型映射缓存",
|
||||||
|
"provider_id": self.provider_id,
|
||||||
|
"global_model_id": self.global_model_id,
|
||||||
|
"deleted_keys": deleted_keys,
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||||
|
|||||||
@@ -395,3 +395,24 @@ class BaseMessageHandler:
|
|||||||
|
|
||||||
# 创建后台任务,不阻塞当前流
|
# 创建后台任务,不阻塞当前流
|
||||||
asyncio.create_task(_do_update())
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
|
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||||
|
"""记录请求错误日志,对业务异常不打印堆栈
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 错误消息前缀
|
||||||
|
error: 异常对象
|
||||||
|
"""
|
||||||
|
from src.core.exceptions import (
|
||||||
|
ProviderException,
|
||||||
|
QuotaExceededException,
|
||||||
|
RateLimitException,
|
||||||
|
ModelNotSupportedException,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
||||||
|
# 业务异常:简洁日志,不打印堆栈
|
||||||
|
logger.error(f"{message}: [{type(error).__name__}] {error}")
|
||||||
|
else:
|
||||||
|
# 未知异常:完整堆栈
|
||||||
|
logger.exception(f"{message}: {error}")
|
||||||
|
|||||||
@@ -382,7 +382,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"流式请求失败: {e}")
|
self._log_request_error("流式请求失败", e)
|
||||||
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -413,20 +413,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈
|
self._log_request_error("流式请求失败", e)
|
||||||
from src.core.exceptions import (
|
|
||||||
ProviderException,
|
|
||||||
QuotaExceededException,
|
|
||||||
RateLimitException,
|
|
||||||
ModelNotSupportedException,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
|
||||||
# 业务异常:简洁日志
|
|
||||||
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
|
|
||||||
else:
|
|
||||||
# 未知异常:完整堆栈
|
|
||||||
logger.exception(f"流式请求失败: {e}")
|
|
||||||
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
|
|||||||
ResponseParser,
|
ResponseParser,
|
||||||
StreamStats,
|
StreamStats,
|
||||||
)
|
)
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||||
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
|
|||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
result.input_tokens = usage.get("input_tokens", 0)
|
result.input_tokens = usage.get("input_tokens", 0)
|
||||||
result.output_tokens = usage.get("output_tokens", 0)
|
result.output_tokens = usage.get("output_tokens", 0)
|
||||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
|
||||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||||
|
|
||||||
# 检查错误(支持嵌套错误格式)
|
# 检查错误(支持嵌套错误格式)
|
||||||
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||||
|
# 对于 message_start 事件,usage 在 message.usage 路径下
|
||||||
|
# 对于其他响应,usage 在顶层
|
||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
|
if not usage and "message" in response:
|
||||||
|
usage = response.get("message", {}).get("usage", {})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -104,14 +104,40 @@ class StreamContext:
|
|||||||
cached_tokens: Optional[int] = None,
|
cached_tokens: Optional[int] = None,
|
||||||
cache_creation_tokens: Optional[int] = None,
|
cache_creation_tokens: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新 Token 使用统计"""
|
"""
|
||||||
if input_tokens is not None:
|
更新 Token 使用统计
|
||||||
|
|
||||||
|
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
|
||||||
|
|
||||||
|
设计原理:
|
||||||
|
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
|
||||||
|
- 后续事件可能会提供完整的统计数据
|
||||||
|
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
|
||||||
|
|
||||||
|
示例场景:
|
||||||
|
- message_start 事件:input_tokens=100, output_tokens=0
|
||||||
|
- message_delta 事件:input_tokens=0, output_tokens=50
|
||||||
|
- 最终结果:input_tokens=100, output_tokens=50
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
- 此策略假设初始值为 0 是正确的默认状态
|
||||||
|
- 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tokens: 输入 tokens 数量
|
||||||
|
output_tokens: 输出 tokens 数量
|
||||||
|
cached_tokens: 缓存命中 tokens 数量
|
||||||
|
cache_creation_tokens: 缓存创建 tokens 数量
|
||||||
|
"""
|
||||||
|
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
|
||||||
self.input_tokens = input_tokens
|
self.input_tokens = input_tokens
|
||||||
if output_tokens is not None:
|
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
|
||||||
self.output_tokens = output_tokens
|
self.output_tokens = output_tokens
|
||||||
if cached_tokens is not None:
|
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
|
||||||
self.cached_tokens = cached_tokens
|
self.cached_tokens = cached_tokens
|
||||||
if cache_creation_tokens is not None:
|
if cache_creation_tokens is not None and (
|
||||||
|
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
|
||||||
|
):
|
||||||
self.cache_creation_tokens = cache_creation_tokens
|
self.cache_creation_tokens = cache_creation_tokens
|
||||||
|
|
||||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||||
|
|||||||
31
src/api/handlers/base/utils.py
Normal file
31
src/api/handlers/base/utils.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Handler 基础工具函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||||
|
"""
|
||||||
|
提取缓存创建 tokens(兼容新旧格式)
|
||||||
|
|
||||||
|
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens:
|
||||||
|
- 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和
|
||||||
|
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
|
||||||
|
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
|
||||||
|
|
||||||
|
此函数自动检测并适配两种格式,优先使用新格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage: API 响应中的 usage 字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存创建 tokens 总数
|
||||||
|
"""
|
||||||
|
# 优先使用新格式
|
||||||
|
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
|
||||||
|
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
|
||||||
|
total = int(cache_5m) + int(cache_1h)
|
||||||
|
|
||||||
|
# 如果新格式不存在(total == 0),回退到旧格式
|
||||||
|
return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0))
|
||||||
@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeChatHandler(ChatHandlerBase):
|
class ClaudeChatHandler(ChatHandlerBase):
|
||||||
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
result["model"] = mapped_model
|
result["model"] = mapped_model
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _convert_request(self, request):
|
async def _convert_request(self, request: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
将请求转换为 Claude 格式
|
将请求转换为 Claude 格式
|
||||||
|
|
||||||
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
Claude 格式使用:
|
Claude 格式使用:
|
||||||
- input_tokens / output_tokens
|
- input_tokens / output_tokens
|
||||||
- cache_creation_input_tokens / cache_read_input_tokens
|
- cache_creation_input_tokens / cache_read_input_tokens
|
||||||
|
- 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
|
||||||
"""
|
"""
|
||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
|
|
||||||
input_tokens = usage.get("input_tokens", 0)
|
|
||||||
output_tokens = usage.get("output_tokens", 0)
|
|
||||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
|
||||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
|
||||||
|
|
||||||
# 处理新的 cache_creation 格式
|
|
||||||
if "cache_creation" in usage:
|
|
||||||
cache_creation_data = usage.get("cache_creation", {})
|
|
||||||
if not cache_creation_input_tokens:
|
|
||||||
cache_creation_input_tokens = cache_creation_data.get(
|
|
||||||
"ephemeral_5m_input_tokens", 0
|
|
||||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_tokens": input_tokens,
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": output_tokens,
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_input_tokens": cache_read_input_tokens,
|
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _normalize_response(self, response: Dict) -> Dict:
|
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
规范化 Claude 响应
|
规范化 Claude 响应
|
||||||
|
|
||||||
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||||
return self.response_normalizer.normalize_claude_response(
|
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||||
response_data=response,
|
response_data=response,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeStreamParser:
|
class ClaudeStreamParser:
|
||||||
"""
|
"""
|
||||||
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
|
|||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
|
|||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
|
|||||||
CliMessageHandlerBase,
|
CliMessageHandlerBase,
|
||||||
StreamContext,
|
StreamContext,
|
||||||
)
|
)
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||||
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
|||||||
usage = message.get("usage", {})
|
usage = message.get("usage", {})
|
||||||
if usage:
|
if usage:
|
||||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||||
# Claude 的缓存 tokens 使用不同的字段名
|
|
||||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||||
if cache_read:
|
if cache_read:
|
||||||
ctx.cached_tokens = cache_read
|
ctx.cached_tokens = cache_read
|
||||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
|
||||||
|
cache_creation = extract_cache_creation_tokens(usage)
|
||||||
if cache_creation:
|
if cache_creation:
|
||||||
ctx.cache_creation_tokens = cache_creation
|
ctx.cache_creation_tokens = cache_creation
|
||||||
|
|
||||||
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
|||||||
ctx.input_tokens = usage["input_tokens"]
|
ctx.input_tokens = usage["input_tokens"]
|
||||||
if "output_tokens" in usage:
|
if "output_tokens" in usage:
|
||||||
ctx.output_tokens = usage["output_tokens"]
|
ctx.output_tokens = usage["output_tokens"]
|
||||||
# 更新缓存 tokens
|
|
||||||
|
# 更新缓存读取 tokens
|
||||||
if "cache_read_input_tokens" in usage:
|
if "cache_read_input_tokens" in usage:
|
||||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||||
if "cache_creation_input_tokens" in usage:
|
|
||||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
# 更新缓存创建 tokens
|
||||||
|
cache_creation = extract_cache_creation_tokens(usage)
|
||||||
|
if cache_creation > 0:
|
||||||
|
ctx.cache_creation_tokens = cache_creation
|
||||||
|
|
||||||
# 检查是否结束
|
# 检查是否结束
|
||||||
delta = data.get("delta", {})
|
delta = data.get("delta", {})
|
||||||
|
|||||||
@@ -120,6 +120,33 @@ class CacheService:
|
|||||||
logger.warning(f"缓存检查失败: {key} - {e}")
|
logger.warning(f"缓存检查失败: {key} - {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
|
||||||
|
"""
|
||||||
|
递增缓存值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存键
|
||||||
|
ttl_seconds: 可选,如果提供则刷新 TTL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
递增后的值,如果失败返回 0
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_client(require_redis=False)
|
||||||
|
if not redis:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
result = await redis.incr(key)
|
||||||
|
# 如果提供了 TTL,刷新过期时间
|
||||||
|
if ttl_seconds is not None:
|
||||||
|
await redis.expire(key, ttl_seconds)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"缓存递增失败: {key} - {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
# 缓存键前缀
|
# 缓存键前缀
|
||||||
class CacheKeys:
|
class CacheKeys:
|
||||||
|
|||||||
157
src/services/cache/model_cache.py
vendored
157
src/services/cache/model_cache.py
vendored
@@ -2,11 +2,9 @@
|
|||||||
Model 映射缓存服务 - 减少模型查询
|
Model 映射缓存服务 - 减少模型查询
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.config.constants import CacheTTL
|
from src.config.constants import CacheTTL
|
||||||
@@ -106,6 +104,7 @@ class ModelCacheService:
|
|||||||
Model 对象或 None
|
Model 对象或 None
|
||||||
"""
|
"""
|
||||||
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
|
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
|
||||||
|
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
# 1. 尝试从缓存获取
|
||||||
cached_data = await CacheService.get(cache_key)
|
cached_data = await CacheService.get(cache_key)
|
||||||
@@ -113,6 +112,8 @@ class ModelCacheService:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||||
)
|
)
|
||||||
|
# 递增命中计数,同时刷新 TTL
|
||||||
|
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||||
return ModelCacheService._dict_to_model(cached_data)
|
return ModelCacheService._dict_to_model(cached_data)
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
# 2. 缓存未命中,查询数据库
|
||||||
@@ -130,6 +131,8 @@ class ModelCacheService:
|
|||||||
if model:
|
if model:
|
||||||
model_dict = ModelCacheService._model_to_dict(model)
|
model_dict = ModelCacheService._model_to_dict(model)
|
||||||
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
|
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||||
|
# 重置命中计数(新缓存从1开始)
|
||||||
|
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||||
)
|
)
|
||||||
@@ -189,9 +192,10 @@ class ModelCacheService:
|
|||||||
# 清除 model:id 缓存
|
# 清除 model:id 缓存
|
||||||
await CacheService.delete(f"model:id:{model_id}")
|
await CacheService.delete(f"model:id:{model_id}")
|
||||||
|
|
||||||
# 清除 provider_global 缓存(如果提供了必要参数)
|
# 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
|
||||||
if provider_id and global_model_id:
|
if provider_id and global_model_id:
|
||||||
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
|
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
|
||||||
|
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
|
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
|
||||||
)
|
)
|
||||||
@@ -230,16 +234,20 @@ class ModelCacheService:
|
|||||||
db: Session, model_name: str
|
db: Session, model_name: str
|
||||||
) -> Optional[GlobalModel]:
|
) -> Optional[GlobalModel]:
|
||||||
"""
|
"""
|
||||||
通过名称或映射解析 GlobalModel(带缓存,支持映射匹配)
|
通过名称解析 GlobalModel(带缓存)
|
||||||
|
|
||||||
查找顺序:
|
查找顺序:
|
||||||
1. 检查缓存
|
1. 检查缓存
|
||||||
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases)
|
2. 通过 provider_model_name 匹配(查询 Model 表)
|
||||||
3. 直接匹配 GlobalModel.name(兜底)
|
3. 直接匹配 GlobalModel.name(兜底)
|
||||||
|
|
||||||
|
注意:此方法不使用 provider_model_aliases 进行全局解析。
|
||||||
|
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
|
||||||
|
由 resolve_provider_model() 处理。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
|
model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GlobalModel 对象或 None
|
GlobalModel 对象或 None
|
||||||
@@ -273,116 +281,53 @@ class ModelCacheService:
|
|||||||
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
||||||
return ModelCacheService._dict_to_global_model(cached_data)
|
return ModelCacheService._dict_to_global_model(cached_data)
|
||||||
|
|
||||||
# 2. 优先通过 provider_model_name 和映射名称匹配(Provider 配置优先级最高)
|
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases)
|
||||||
from sqlalchemy import or_
|
# 重要:provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
|
||||||
|
# 全局解析不应该受到某个 Provider 别名配置的影响
|
||||||
|
# 例如:Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
|
||||||
from src.models.database import Provider
|
from src.models.database import Provider
|
||||||
|
|
||||||
# 构建精确的映射匹配条件
|
models_with_global = (
|
||||||
# 注意:provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
|
db.query(Model, GlobalModel)
|
||||||
# 对于 SQLite,会在 Python 层面进行过滤
|
.join(Provider, Model.provider_id == Provider.id)
|
||||||
try:
|
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||||
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
|
.filter(
|
||||||
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
|
Provider.is_active == True,
|
||||||
jsonb_pattern = json.dumps([{"name": normalized_name}])
|
Model.is_active == True,
|
||||||
models_with_global = (
|
GlobalModel.is_active == True,
|
||||||
db.query(Model, GlobalModel)
|
Model.provider_model_name == normalized_name,
|
||||||
.join(Provider, Model.provider_id == Provider.id)
|
|
||||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
|
||||||
.filter(
|
|
||||||
Provider.is_active == True,
|
|
||||||
Model.is_active == True,
|
|
||||||
GlobalModel.is_active == True,
|
|
||||||
or_(
|
|
||||||
Model.provider_model_name == normalized_name,
|
|
||||||
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
|
|
||||||
Model.provider_model_aliases.op("@>")(jsonb_pattern),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
except (OperationalError, ProgrammingError) as e:
|
|
||||||
# JSONB 操作符不支持(如 SQLite),回退到加载匹配 provider_model_name 的 Model
|
|
||||||
# 并在 Python 层过滤 aliases
|
|
||||||
logger.debug(
|
|
||||||
f"JSONB 查询失败,回退到 Python 过滤: {e}",
|
|
||||||
)
|
|
||||||
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
|
|
||||||
models_with_global = (
|
|
||||||
db.query(Model, GlobalModel)
|
|
||||||
.join(Provider, Model.provider_id == Provider.id)
|
|
||||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
|
||||||
.filter(
|
|
||||||
Provider.is_active == True,
|
|
||||||
Model.is_active == True,
|
|
||||||
GlobalModel.is_active == True,
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)}
|
# 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
|
||||||
# 使用字典去重,同一个 Model 只保留优先级最高的匹配
|
matched_global_models: List[GlobalModel] = []
|
||||||
matched_models_dict = {}
|
seen_global_model_ids: set[str] = set()
|
||||||
|
|
||||||
# 遍历查询结果进行匹配
|
|
||||||
for model, gm in models_with_global:
|
for model, gm in models_with_global:
|
||||||
key = (model.id, gm.id)
|
if gm.id not in seen_global_model_ids:
|
||||||
|
seen_global_model_ids.add(gm.id)
|
||||||
# 检查 provider_model_aliases 是否匹配(优先级更高)
|
matched_global_models.append(gm)
|
||||||
if model.provider_model_aliases:
|
logger.debug(
|
||||||
for alias_entry in model.provider_model_aliases:
|
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
||||||
if isinstance(alias_entry, dict):
|
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||||
alias_name = alias_entry.get("name", "").strip()
|
|
||||||
if alias_name == normalized_name:
|
|
||||||
# alias 优先级为 0(最高),覆盖任何已存在的匹配
|
|
||||||
matched_models_dict[key] = (gm, "alias", 0)
|
|
||||||
logger.debug(
|
|
||||||
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
|
|
||||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
|
|
||||||
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
|
|
||||||
if model.provider_model_name == normalized_name:
|
|
||||||
# provider_model_name 优先级为 1(兜底),只在没有 alias 匹配时使用
|
|
||||||
if key not in matched_models_dict:
|
|
||||||
matched_models_dict[key] = (gm, "provider_model_name", 1)
|
|
||||||
logger.debug(
|
|
||||||
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
|
||||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果通过 provider_model_name/alias 找到了,直接返回
|
|
||||||
if matched_models_dict:
|
|
||||||
# 转换为列表并排序:按 priority(alias=0 优先)、然后按 GlobalModel.name
|
|
||||||
matched_global_models = [
|
|
||||||
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
|
|
||||||
]
|
|
||||||
matched_global_models.sort(
|
|
||||||
key=lambda item: (
|
|
||||||
0 if item[1] == "alias" else 1, # alias 优先
|
|
||||||
item[0].name # 同优先级按名称排序(确定性)
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# 记录解析方式
|
# 如果通过 provider_model_name 找到了,返回
|
||||||
resolution_method = matched_global_models[0][1]
|
if matched_global_models:
|
||||||
|
resolution_method = "provider_model_name"
|
||||||
|
|
||||||
if len(matched_global_models) > 1:
|
if len(matched_global_models) > 1:
|
||||||
# 检测到冲突
|
# 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name)
|
||||||
unique_models = {gm.id: gm for gm, _ in matched_global_models}
|
model_names = [gm.name for gm in matched_global_models if gm.name]
|
||||||
if len(unique_models) > 1:
|
logger.warning(
|
||||||
model_names = [gm.name for gm in unique_models.values()]
|
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
||||||
logger.warning(
|
f"{', '.join(model_names)},使用第一个匹配结果"
|
||||||
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
)
|
||||||
f"{', '.join(model_names)},使用第一个匹配结果"
|
# 记录冲突指标
|
||||||
)
|
model_mapping_conflict_total.inc()
|
||||||
# 记录冲突指标
|
|
||||||
model_mapping_conflict_total.inc()
|
|
||||||
|
|
||||||
# 返回第一个匹配的 GlobalModel
|
# 返回第一个匹配的 GlobalModel
|
||||||
result_global_model: GlobalModel = matched_global_models[0][0]
|
result_global_model = matched_global_models[0]
|
||||||
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
|
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
|
||||||
await CacheService.set(
|
await CacheService.set(
|
||||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
- 根据 API 格式或端点配置生成请求 URL
|
- 根据 API 格式或端点配置生成请求 URL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
|
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
|
||||||
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
|
|||||||
from src.core.enums import APIFormat
|
from src.core.enums import APIFormat
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_provider_headers(
|
def build_provider_headers(
|
||||||
endpoint,
|
endpoint: "ProviderEndpoint",
|
||||||
key,
|
key: "ProviderAPIKey",
|
||||||
original_headers: Optional[Dict[str, str]] = None,
|
original_headers: Optional[Dict[str, str]] = None,
|
||||||
*,
|
*,
|
||||||
extra_headers: Optional[Dict[str, str]] = None,
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
@@ -28,7 +31,8 @@ def build_provider_headers(
|
|||||||
"""
|
"""
|
||||||
headers: Dict[str, str] = {}
|
headers: Dict[str, str] = {}
|
||||||
|
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
# api_key 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||||
|
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
|
||||||
|
|
||||||
# 根据 API 格式自动选择认证头
|
# 根据 API 格式自动选择认证头
|
||||||
api_format = getattr(endpoint, "api_format", None)
|
api_format = getattr(endpoint, "api_format", None)
|
||||||
@@ -68,8 +72,32 @@ def build_provider_headers(
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_base_url(base_url: str, path: str) -> str:
|
||||||
|
"""
|
||||||
|
规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。
|
||||||
|
|
||||||
|
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
|
||||||
|
避免拼接出 /v1/v1/messages 这样的重复路径。
|
||||||
|
|
||||||
|
兼容用户填写的各种格式:
|
||||||
|
- https://api.example.com
|
||||||
|
- https://api.example.com/
|
||||||
|
- https://api.example.com/v1
|
||||||
|
- https://api.example.com/v1/
|
||||||
|
"""
|
||||||
|
base = base_url.rstrip("/")
|
||||||
|
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
|
||||||
|
# 例如:base="/v1", path="/v1/messages" -> 去除 /v1
|
||||||
|
# 例如:base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
|
||||||
|
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
|
||||||
|
if base.endswith(suffix) and path.startswith(suffix):
|
||||||
|
base = base[: -len(suffix)]
|
||||||
|
break
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
def build_provider_url(
|
def build_provider_url(
|
||||||
endpoint,
|
endpoint: "ProviderEndpoint",
|
||||||
*,
|
*,
|
||||||
query_params: Optional[Dict[str, Any]] = None,
|
query_params: Optional[Dict[str, Any]] = None,
|
||||||
path_params: Optional[Dict[str, Any]] = None,
|
path_params: Optional[Dict[str, Any]] = None,
|
||||||
@@ -88,8 +116,6 @@ def build_provider_url(
|
|||||||
path_params: 路径模板参数 (如 {model})
|
path_params: 路径模板参数 (如 {model})
|
||||||
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
|
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
|
||||||
"""
|
"""
|
||||||
base = endpoint.base_url.rstrip("/")
|
|
||||||
|
|
||||||
# 准备路径参数,添加 Gemini API 所需的 action 参数
|
# 准备路径参数,添加 Gemini API 所需的 action 参数
|
||||||
effective_path_params = dict(path_params) if path_params else {}
|
effective_path_params = dict(path_params) if path_params else {}
|
||||||
|
|
||||||
@@ -123,6 +149,9 @@ def build_provider_url(
|
|||||||
if not path.startswith("/"):
|
if not path.startswith("/"):
|
||||||
path = f"/{path}"
|
path = f"/{path}"
|
||||||
|
|
||||||
|
# 先确定 path,再根据 path 规范化 base_url
|
||||||
|
# base_url 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||||
|
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
|
||||||
url = f"{base}{path}"
|
url = f"{base}{path}"
|
||||||
|
|
||||||
# 添加查询参数
|
# 添加查询参数
|
||||||
@@ -134,7 +163,7 @@ def build_provider_url(
|
|||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
def _resolve_default_path(api_format) -> str:
|
def _resolve_default_path(api_format: Optional[str]) -> str:
|
||||||
"""
|
"""
|
||||||
根据 API 格式返回默认路径
|
根据 API 格式返回默认路径
|
||||||
"""
|
"""
|
||||||
|
|||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""测试模块"""
|
||||||
90
tests/api/handlers/base/test_utils.py
Normal file
90
tests/api/handlers/base/test_utils.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""测试 handler 基础工具函数"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractCacheCreationTokens:
|
||||||
|
"""测试 extract_cache_creation_tokens 函数"""
|
||||||
|
|
||||||
|
def test_new_format_only(self) -> None:
|
||||||
|
"""测试只有新格式字段"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
|
def test_new_format_5m_only(self) -> None:
|
||||||
|
"""测试只有 5 分钟缓存"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 150,
|
||||||
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 150
|
||||||
|
|
||||||
|
def test_new_format_1h_only(self) -> None:
|
||||||
|
"""测试只有 1 小时缓存"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
|
"claude_cache_creation_1_h_tokens": 250,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 250
|
||||||
|
|
||||||
|
def test_old_format_only(self) -> None:
|
||||||
|
"""测试只有旧格式字段"""
|
||||||
|
usage = {
|
||||||
|
"cache_creation_input_tokens": 500,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 500
|
||||||
|
|
||||||
|
def test_both_formats_prefers_new(self) -> None:
|
||||||
|
"""测试同时存在时优先使用新格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
|
"cache_creation_input_tokens": 999, # 应该被忽略
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
|
def test_empty_usage(self) -> None:
|
||||||
|
"""测试空字典"""
|
||||||
|
usage = {}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
|
def test_all_zeros(self) -> None:
|
||||||
|
"""测试所有字段都为 0"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
|
"cache_creation_input_tokens": 0,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
|
def test_partial_new_format_with_old_format_fallback(self) -> None:
|
||||||
|
"""测试新格式字段不存在时回退到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"cache_creation_input_tokens": 123,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 123
|
||||||
|
|
||||||
|
def test_new_format_zero_fallback_to_old(self) -> None:
|
||||||
|
"""测试新格式为 0 时回退到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
|
"cache_creation_input_tokens": 456,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 456
|
||||||
|
|
||||||
|
def test_unrelated_fields_ignored(self) -> None:
|
||||||
|
"""测试忽略无关字段"""
|
||||||
|
usage = {
|
||||||
|
"input_tokens": 1000,
|
||||||
|
"output_tokens": 2000,
|
||||||
|
"cache_read_input_tokens": 300,
|
||||||
|
"claude_cache_creation_5_m_tokens": 50,
|
||||||
|
"claude_cache_creation_1_h_tokens": 75,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 125
|
||||||
Reference in New Issue
Block a user