perf: 添加多层缓存优化减少数据库查询

- 新增 ProviderCacheService 缓存 Provider 和 ProviderAPIKey 数据
- SystemConfigService 添加进程内缓存(TTL 60秒)
- API Key last_used_at 更新添加节流策略(60秒间隔)
- HTTP 连接池配置改为可配置,支持根据 Worker 数量自动计算
- 前端优先级管理改用 health_score 显示健康度
This commit is contained in:
fawney19
2026-01-08 02:34:59 +08:00
parent d9e6346911
commit d378630b38
9 changed files with 374 additions and 31 deletions

View File

@@ -262,17 +262,17 @@
<div class="shrink-0 flex items-center gap-3"> <div class="shrink-0 flex items-center gap-3">
<!-- 健康度 --> <!-- 健康度 -->
<div <div
v-if="key.success_rate !== null" v-if="key.health_score != null"
class="text-xs text-right" class="text-xs text-right"
> >
<div <div
class="font-medium tabular-nums" class="font-medium tabular-nums"
:class="[ :class="[
key.success_rate >= 0.95 ? 'text-green-600' : key.health_score >= 0.95 ? 'text-green-600' :
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500' key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
]" ]"
> >
{{ (key.success_rate * 100).toFixed(0) }}% {{ ((key.health_score || 0) * 100).toFixed(0) }}%
</div> </div>
<div class="text-[10px] text-muted-foreground opacity-70"> <div class="text-[10px] text-muted-foreground opacity-70">
{{ key.request_count }} reqs {{ key.request_count }} reqs
@@ -400,6 +400,7 @@ interface KeyWithMeta {
endpoint_base_url: string endpoint_base_url: string
api_format: string api_format: string
capabilities: string[] capabilities: string[]
health_score: number | null
success_rate: number | null success_rate: number | null
avg_response_time_ms: number | null avg_response_time_ms: number | null
request_count: number request_count: number

View File

@@ -18,6 +18,7 @@ from src.core.key_capabilities import get_capability
from src.core.logger import logger from src.core.logger import logger
from src.database import get_db from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import ( from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest, BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate, EndpointAPIKeyCreate,
@@ -411,6 +412,10 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit() db.commit()
db.refresh(key) db.refresh(key)
# 如果更新了 rate_multiplier清除缓存
if "rate_multiplier" in update_data:
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys())) logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try: try:
@@ -550,6 +555,7 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
"endpoint_base_url": endpoint.base_url, "endpoint_base_url": endpoint.base_url,
"api_format": api_format, "api_format": api_format,
"capabilities": caps_list, "capabilities": caps_list,
"health_score": key.health_score,
"success_rate": success_rate, "success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms, "avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count, "request_count": key.request_count,

View File

@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider from src.models.database import Provider
from src.services.cache.provider_cache import ProviderCacheService
router = APIRouter(tags=["Provider CRUD"]) router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@@ -296,6 +298,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
db.commit() db.commit()
db.refresh(provider) db.refresh(provider)
# 如果更新了 billing_type清除缓存
if "billing_type" in update_data:
await ProviderCacheService.invalidate_provider_cache(provider.id)
logger.debug(f"已清除 Provider 缓存: {provider.id}")
context.add_audit_metadata( context.add_audit_metadata(
action="update_provider", action="update_provider",
provider_id=provider.id, provider_id=provider.id,

View File

@@ -90,13 +90,18 @@ class HTTPClientPool:
pool=config.http_pool_timeout, pool=config.http_pool_timeout,
), ),
limits=httpx.Limits( limits=httpx.Limits(
max_connections=100, # 最大连接数 max_connections=config.http_max_connections,
max_keepalive_connections=20, # 最大保活连接数 max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=30.0, # 保活过期时间(秒) keepalive_expiry=config.http_keepalive_expiry,
), ),
follow_redirects=True, # 跟随重定向 follow_redirects=True, # 跟随重定向
) )
logger.info("全局HTTP客户端池已初始化") logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client return cls._default_client
@classmethod @classmethod

View File

@@ -145,6 +145,24 @@ class Config:
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
# HTTP 连接池配置
# HTTP_MAX_CONNECTIONS: 最大连接数,影响并发能力
# - 每个连接占用一个 socket过多会耗尽系统资源
# - 默认根据 Worker 数量自动计算:单 Worker 200多 Worker 按比例分配
# HTTP_KEEPALIVE_CONNECTIONS: 保活连接数,影响连接复用效率
# - 高频请求场景应该增大此值
# - 默认为 max_connections 的 30%(长连接场景更高效)
# HTTP_KEEPALIVE_EXPIRY: 保活过期时间(秒)
# - 过短会频繁重建连接,过长会占用资源
# - 默认 30 秒,生图等长连接场景可适当增大
self.http_max_connections = int(
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
)
self.http_keepalive_connections = int(
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
)
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
# 流式处理配置 # 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误 # STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
@@ -224,6 +242,53 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同""" """智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size return self.db_pool_size
def _auto_http_max_connections(self) -> int:
"""
智能计算 HTTP 最大连接数
计算依据:
1. 系统 socket 资源有限Linux 默认 ulimit -n 通常为 1024
2. 多 Worker 部署时每个进程独立连接池
3. 需要为数据库连接、Redis 连接等预留资源
公式: base_connections / workers
- 单 Worker: 200 连接(适合开发/低负载)
- 多 Worker: 按比例分配,确保总数不超过系统限制
范围: 50 - 500
"""
# 基础连接数:假设系统可用 socket 约 800 个用于 HTTP
# (预留给 DB、Redis、内部服务等
base_connections = 800
workers = max(self.worker_processes, 1)
# 每个 Worker 分配的连接数
per_worker = base_connections // workers
# 限制范围:最小 50保证基本并发最大 500避免资源耗尽
return max(50, min(per_worker, 500))
def _auto_http_keepalive_connections(self) -> int:
"""
智能计算 HTTP 保活连接数
计算依据:
1. 保活连接用于复用,减少 TCP 握手开销
2. 对于 API 网关场景,上游请求频繁,保活比例应较高
3. 生图等长连接场景,连接会被长时间占用
公式: max_connections * 0.3
- 30% 的比例在复用效率和资源占用间取得平衡
- 长连接场景建议手动调高到 50-70%
范围: 10 - max_connections
"""
# 保活连接数为最大连接数的 30%
keepalive = int(self.http_max_connections * 0.3)
# 最小 10 个保活连接,最大不超过 max_connections
return max(10, min(keepalive, self.http_max_connections))
def _parse_ttfb_timeout(self) -> float: def _parse_ttfb_timeout(self) -> float:
""" """
解析 TTFB 超时配置,带错误处理和范围限制 解析 TTFB 超时配置,带错误处理和范围限制

View File

@@ -8,7 +8,9 @@ import hashlib
import secrets import secrets
import time import time
import uuid import uuid
from collections import OrderedDict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import jwt import jwt
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService from src.services.user.apikey import ApiKeyService
# API Key last_used_at 更新节流配置
# 同一个 API Key 在此时间间隔内只会更新一次 last_used_at
_LAST_USED_UPDATE_INTERVAL = 60 # 秒
_LAST_USED_CACHE_MAX_SIZE = 10000 # LRU 缓存最大条目数
# 进程内缓存:记录每个 API Key 最后一次更新 last_used_at 的时间
# 使用 OrderedDict 实现 LRU避免内存无限增长
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
_last_update_lock = Lock()
def _should_update_last_used(api_key_id: str) -> bool:
"""判断是否应该更新 API Key 的 last_used_at
使用节流策略,同一个 Key 在指定间隔内只更新一次。
线程安全,使用 LRU 策略限制缓存大小。
Returns:
True 表示应该更新False 表示跳过
"""
now = time.time()
with _last_update_lock:
last_update = _api_key_last_update_times.get(api_key_id, 0)
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
_api_key_last_update_times[api_key_id] = now
# LRU: 移到末尾(最近使用)
_api_key_last_update_times.move_to_end(api_key_id)
# 超过最大容量时,移除最旧的条目
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
_api_key_last_update_times.popitem(last=False)
return True
return False
# JWT配置从config读取 # JWT配置从config读取
if not config.jwt_secret_key: if not config.jwt_secret_key:
# 如果没有配置,生成一个随机密钥并警告 # 如果没有配置,生成一个随机密钥并警告
@@ -367,9 +407,10 @@ class AuthService:
logger.warning(f"API认证失败 - 用户已禁用: {user.email}") logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
return None return None
# 更新最后使用时间 # 更新最后使用时间(使用节流策略,减少数据库写入)
key_record.last_used_at = datetime.now(timezone.utc) if _should_update_last_used(key_record.id):
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求 key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12] api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp) logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)

171
src/services/cache/provider_cache.py vendored Normal file
View File

@@ -0,0 +1,171 @@
"""
Provider 缓存服务 - 减少 Provider 和 ProviderAPIKey 查询
用于缓存 Provider 的 billing_type 和 ProviderAPIKey 的 rate_multiplier
这些数据在 UsageService.record_usage() 中被频繁查询但变化不频繁。
"""
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey
class ProviderCacheService:
    """Cache service for Provider / ProviderAPIKey lookups.

    Caches Provider.billing_type and ProviderAPIKey.rate_multiplier —
    values that are read frequently (notably from UsageService when
    recording usage) but change rarely — to cut database round-trips.
    Negative results are cached as the sentinel string "NOT_FOUND".
    """
    # Shared TTL for every entry this service writes (5 minutes).
    CACHE_TTL = CacheTTL.PROVIDER

    @staticmethod
    async def get_provider_api_key_rate_multiplier(
        db: Session, provider_api_key_id: str
    ) -> Optional[float]:
        """Return the rate_multiplier of a ProviderAPIKey, cache-first.

        Args:
            db: database session used on cache miss.
            provider_api_key_id: ProviderAPIKey primary key.

        Returns:
            The multiplier as float (a NULL/0 value in the DB is coerced
            to 1.0), or None when the key does not exist.
        """
        cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}"
        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
            # Cached sentinel "NOT_FOUND" means the row does not exist.
            if cached_data == "NOT_FOUND":
                return None
            return float(cached_data)
        # 2. Cache miss: query only the needed column from the database.
        provider_key = (
            db.query(ProviderAPIKey.rate_multiplier)
            .filter(ProviderAPIKey.id == provider_api_key_id)
            .first()
        )
        # 3. Populate the cache with the result (positive or negative).
        if provider_key:
            # Falsy multiplier (None or 0) defaults to 1.0 before caching.
            rate_multiplier = provider_key.rate_multiplier or 1.0
            await CacheService.set(
                cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
            return rate_multiplier
        else:
            # Cache the negative result to avoid repeated misses.
            await CacheService.set(
                cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            return None

    @staticmethod
    async def get_provider_billing_type(
        db: Session, provider_id: str
    ) -> Optional[ProviderBillingType]:
        """Return the billing_type of a Provider, cache-first.

        Args:
            db: database session used on cache miss.
            provider_id: Provider primary key.

        Returns:
            The ProviderBillingType, or None when the provider does not
            exist.
        """
        cache_key = f"provider:billing_type:{provider_id}"
        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
            if cached_data == "NOT_FOUND":
                return None
            try:
                return ProviderBillingType(cached_data)
            except ValueError:
                # Cached value is not a valid enum member: evict it and
                # fall through to re-query the database.
                await CacheService.delete(cache_key)
        # 2. Cache miss (or invalid entry): query the database.
        provider = (
            db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
        )
        # 3. Populate the cache with the result (positive or negative).
        if provider:
            billing_type = provider.billing_type
            # Store the enum's value (a plain string) so it round-trips
            # through the cache backend.
            await CacheService.set(
                cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
            return billing_type
        else:
            # Cache the negative result to avoid repeated misses.
            await CacheService.set(
                cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            return None

    @staticmethod
    async def get_rate_multiplier_and_free_tier(
        db: Session,
        provider_api_key_id: Optional[str],
        provider_id: Optional[str],
    ) -> Tuple[float, bool]:
        """Return the rate multiplier and free-tier flag, cache-backed.

        Cached counterpart of
        UsageService._get_rate_multiplier_and_free_tier.

        Args:
            db: database session.
            provider_api_key_id: ProviderAPIKey ID (optional).
            provider_id: Provider ID (optional).

        Returns:
            (rate_multiplier, is_free_tier) tuple; defaults are
            (1.0, False) when the corresponding ID is missing or unknown.
        """
        actual_rate_multiplier = 1.0
        is_free_tier = False
        # Resolve the rate multiplier from the ProviderAPIKey cache.
        if provider_api_key_id:
            rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
                db, provider_api_key_id
            )
            if rate_multiplier is not None:
                actual_rate_multiplier = rate_multiplier
        # Resolve the billing type from the Provider cache.
        if provider_id:
            billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
            if billing_type == ProviderBillingType.FREE_TIER:
                is_free_tier = True
        return actual_rate_multiplier, is_free_tier

    @staticmethod
    async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
        """Evict the cached rate_multiplier for one ProviderAPIKey."""
        await CacheService.delete(f"provider_api_key:rate_multiplier:{provider_api_key_id}")
        logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")

    @staticmethod
    async def invalidate_provider_cache(provider_id: str) -> None:
        """Evict the cached billing_type for one Provider."""
        await CacheService.delete(f"provider:billing_type:{provider_id}")
        logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")

View File

@@ -3,8 +3,9 @@
""" """
import json import json
import time
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -20,6 +21,49 @@ class LogLevel(str, Enum):
FULL = "full" # 记录完整请求和响应包含body敏感信息会脱敏 FULL = "full" # 记录完整请求和响应包含body敏感信息会脱敏
# 进程内缓存 TTL- 系统配置变化不频繁,使用较长的 TTL
_CONFIG_CACHE_TTL = 60 # 1 分钟
# 进程内缓存存储: {key: (value, expire_time)}
_config_cache: Dict[str, Tuple[Any, float]] = {}
def _get_cached_config(key: str) -> Tuple[bool, Any]:
"""从进程内缓存获取配置值
Returns:
(hit, value): hit=True 表示缓存命中value 为缓存的值
"""
if key in _config_cache:
value, expire_time = _config_cache[key]
if time.time() < expire_time:
return True, value
# 缓存过期,安全删除(避免并发时 KeyError
_config_cache.pop(key, None)
return False, None
def _set_cached_config(key: str, value: Any) -> None:
"""设置进程内缓存"""
_config_cache[key] = (value, time.time() + _CONFIG_CACHE_TTL)
def invalidate_config_cache(key: Optional[str] = None) -> None:
    """Drop cached system-configuration entries.

    Args:
        key: specific configuration key to evict; None wipes the whole
            cache.
    """
    global _config_cache
    if key is not None:
        # pop() is safe under concurrent deletion (no KeyError).
        if _config_cache.pop(key, None) is not None:
            logger.debug(f"已清除系统配置缓存: {key}")
    else:
        # Rebind to a fresh dict rather than mutating in place.
        _config_cache = {}
        logger.debug("已清除所有系统配置缓存")
class SystemConfigService: class SystemConfigService:
"""系统配置服务类""" """系统配置服务类"""
@@ -127,14 +171,23 @@ class SystemConfigService:
@classmethod @classmethod
def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]: def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
"""获取系统配置值""" """获取系统配置值(带进程内缓存)"""
# 1. 检查进程内缓存
hit, cached_value = _get_cached_config(key)
if hit:
return cached_value
# 2. 查询数据库
config = db.query(SystemConfig).filter(SystemConfig.key == key).first() config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config: if config:
_set_cached_config(key, config.value)
return config.value return config.value
# 如果配置不存在,检查默认值 # 3. 如果配置不存在,使用默认值
if key in cls.DEFAULT_CONFIGS: if key in cls.DEFAULT_CONFIGS:
return cls.DEFAULT_CONFIGS[key]["value"] value = cls.DEFAULT_CONFIGS[key]["value"]
_set_cached_config(key, value)
return value
return default return default
@@ -185,6 +238,9 @@ class SystemConfigService:
db.commit() db.commit()
db.refresh(config) db.refresh(config)
# 清除缓存
invalidate_config_cache(key)
return config return config
@staticmethod @staticmethod
@@ -243,6 +299,8 @@ class SystemConfigService:
if config: if config:
db.delete(config) db.delete(config)
db.commit() db.commit()
# 清除缓存
invalidate_config_cache(key)
return True return True
return False return False

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.core.enums import ProviderBillingType
from src.core.logger import logger from src.core.logger import logger
from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole
from src.services.model.cost import ModelCostService from src.services.model.cost import ModelCostService
@@ -362,22 +361,12 @@ class UsageService:
provider_api_key_id: Optional[str], provider_api_key_id: Optional[str],
provider_id: Optional[str], provider_id: Optional[str],
) -> Tuple[float, bool]: ) -> Tuple[float, bool]:
"""获取费率倍数和是否免费套餐""" """获取费率倍数和是否免费套餐(使用缓存)"""
actual_rate_multiplier = 1.0 from src.services.cache.provider_cache import ProviderCacheService
if provider_api_key_id:
provider_key = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first()
)
if provider_key and provider_key.rate_multiplier:
actual_rate_multiplier = provider_key.rate_multiplier
is_free_tier = False return await ProviderCacheService.get_rate_multiplier_and_free_tier(
if provider_id: db, provider_api_key_id, provider_id
provider_obj = db.query(Provider).filter(Provider.id == provider_id).first() )
if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER:
is_free_tier = True
return actual_rate_multiplier, is_free_tier
@classmethod @classmethod
async def _calculate_costs( async def _calculate_costs(