38 Commits

Author SHA1 Message Date
fawney19
6aa1876955 feat: add Dockerfile.app.local for China mirror support 2025-12-19 16:20:02 +08:00
fawney19
7f07122aea refactor: separate frontend build from base image for faster incremental builds 2025-12-19 16:02:38 +08:00
fawney19
c2ddc6bd3c refactor: optimize Docker build with multi-stage and slim runtime base image 2025-12-19 15:51:21 +08:00
fawney19
af476ff21e feat: enhance error logging and upstream response tracking for provider failures 2025-12-19 15:29:48 +08:00
fawney19
3bbc1c6b66 feat: add provider compatibility error detection for intelligent failover
- Introduce ProviderCompatibilityException for unsupported parameter/feature errors
- Add COMPATIBILITY_ERROR_PATTERNS to detect provider-specific limitations
- Implement _is_compatibility_error() method in ErrorClassifier
- Prioritize compatibility error checking before client error validation
- Remove 'max_tokens' from CLIENT_ERROR_PATTERNS as it can indicate compatibility issues
- Enable automatic failover when provider doesn't support requested features
- Improve error classification accuracy with pattern matching for common compatibility issues
2025-12-19 13:28:26 +08:00
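In sketch form, the classification flow described above looks like this (ProviderCompatibilityException and COMPATIBILITY_ERROR_PATTERNS come from the commit message; the concrete patterns and method shapes are illustrative assumptions):

import re

# Illustrative patterns only; the real list lives in the ErrorClassifier.
COMPATIBILITY_ERROR_PATTERNS = [
    r"unsupported parameter",
    r"does not support",
    r"max_tokens.*(invalid|exceed|unsupported)",
]

class ProviderCompatibilityException(Exception):
    """Provider rejected the request for an unsupported parameter or feature."""

class ErrorClassifier:
    def _is_compatibility_error(self, message: str) -> bool:
        msg = message.lower()
        return any(re.search(p, msg) for p in COMPATIBILITY_ERROR_PATTERNS)

    def classify(self, message: str) -> Exception | None:
        # Compatibility errors are checked before generic client errors so the
        # router can fail over to another provider instead of returning a 4xx.
        if self._is_compatibility_error(message):
            return ProviderCompatibilityException(message)
        return None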
fawney19
c69a0a8506 refactor: remove stream smoothing config from system settings and improve base image caching
- Remove stream_smoothing configuration from SystemConfigService (moved to handler default)
- Remove stream smoothing UI controls from admin settings page
- Add AdminClearSingleAffinityAdapter for targeted cache invalidation
- Add clearSingleAffinity API endpoint to clear specific affinity cache entries
- Include global_model_id in affinity list response for UI deletion support
- Improve CI/CD workflow with hash-based base image change detection
- Add hash label to base image for reliable cache invalidation detection
- Use remote image inspection to determine if base image rebuild is needed
- Include Dockerfile.base in hash calculation for proper dependency tracking
2025-12-19 13:09:56 +08:00
fawney19
1fae202bde Merge pull request #30 from AAEE86/master
chore: Modify the order of API format enumeration
2025-12-19 12:34:22 +08:00
fawney19
b9a26c4550 fix: add SETUPTOOLS_SCM_PRETEND_VERSION for CI builds 2025-12-19 12:01:19 +08:00
AAEE86
e42bd35d48 chore: Modify the order of API format enumeration - Move CLAUDE_CLI before OPENAI 2025-12-19 11:44:10 +08:00
fawney19
f22a073fd9 fix: rebuild app image when base image changes during deployment
- Track BASE_REBUILT flag to detect base image rebuilds
- Force app image rebuild when base image is rebuilt
- Prevents stale app images built with outdated base images
- Ensures consistent deployment when base dependencies change
2025-12-19 11:32:43 +08:00
fawney19
5c7ad089d2 fix: disable nginx buffering for streaming responses
- Add X-Accel-Buffering: no header to prevent nginx from buffering streamed content
- Ensures immediate delivery of each chunk without proxy buffering delays
- Improves real-time streaming performance and responsiveness
- Applies to both production and local Dockerfiles
2025-12-19 11:26:15 +08:00
fawney19
97425ac68f refactor: make stream smoothing parameters configurable and add models cache invalidation
- Move stream smoothing parameters (chunk_size, delay_ms) to database config
- Remove hardcoded stream smoothing constants from StreamProcessor
- Simplify dynamic delay calculation by using config values directly
- Add invalidate_models_list_cache() function to clear /v1/models endpoint cache
- Call cache invalidation on model create, update, delete, and bulk operations
- Update admin UI to allow runtime configuration of smoothing parameters
- Improve model listing freshness when models are modified
2025-12-19 11:03:46 +08:00
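The cache-invalidation hook reduces to something like this sketch (invalidate_models_list_cache is named above; the cache shape and TTL are assumptions):

import time
from typing import Callable

_models_cache: tuple[float, list[dict]] | None = None
_CACHE_TTL = 60.0  # seconds (assumed)

def get_models_list(fetch: Callable[[], list[dict]]) -> list[dict]:
    """Serve /v1/models from a small in-process cache."""
    global _models_cache
    now = time.monotonic()
    if _models_cache and now - _models_cache[0] < _CACHE_TTL:
        return _models_cache[1]
    data = fetch()
    _models_cache = (now, data)
    return data

def invalidate_models_list_cache() -> None:
    """Called on model create/update/delete and bulk operations."""
    global _models_cache
    _models_cache = None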
fawney19
912f6643e2 tune: adjust stream smoothing parameters for better user experience
- Increase chunk size from 5 to 20 characters for fewer delays
- Reduce min delay from 15ms to 8ms for faster playback
- Reduce max delay from 24ms to 15ms for better responsiveness
- Adjust text thresholds to better differentiate content types
- Apply parameter tuning to both StreamProcessor and _LightweightSmoother
2025-12-19 09:51:09 +08:00
fawney19
6c0373fda6 refactor: simplify text splitting logic in stream processor
- Remove complex conditional logic for short/medium/long text differentiation
- Unify text splitting to always use consistent CHUNK_SIZE-based splitting
- Rely on dynamic delay calculation for output speed adjustment
- Reduce code complexity in both main smoother and lightweight smoother
2025-12-19 09:48:11 +08:00
fawney19
070121717d refactor: consolidate stream smoothing into StreamProcessor with intelligent timing
- Move StreamSmoother functionality directly into StreamProcessor for better integration
- Create ContentExtractor strategy pattern for format-agnostic content extraction
- Implement intelligent dynamic delay calculation based on text length
- Support three text length tiers: short (char-by-char), medium (chunked), long (chunked)
- Remove manual chunk_size and delay_ms configuration - now auto-calculated
- Simplify admin UI to single toggle switch with auto timing adjustment
- Extract format detection logic to reusable content_extractors module
- Improve code maintainability with cleaner architecture
2025-12-19 09:46:22 +08:00
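In sketch form, the smoothing described here amounts to splitting each upstream chunk and pacing it with a length-dependent delay (constants follow the tuned values in commit 912f6643e2 above; the interpolation curve is an assumption):

import asyncio
from typing import AsyncIterator

CHUNK_SIZE = 20     # characters per emitted piece
MIN_DELAY = 0.008   # 8 ms
MAX_DELAY = 0.015   # 15 ms

def _dynamic_delay(text_len: int) -> float:
    # Short chunks play back slowly (typing effect); long chunks speed up
    # so the total added latency stays bounded.
    factor = min(text_len / 400, 1.0)
    return MAX_DELAY - (MAX_DELAY - MIN_DELAY) * factor

async def smooth(text: str) -> AsyncIterator[str]:
    """Split one large upstream chunk into CHUNK_SIZE pieces with a pause."""
    delay = _dynamic_delay(len(text))
    for i in range(0, len(text), CHUNK_SIZE):
        yield text[i:i + CHUNK_SIZE]
        await asyncio.sleep(delay)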
fawney19
85fafeacb8 feat: add stream smoothing feature for improved user experience
- Implement StreamSmoother class to split large content chunks into smaller pieces with delay
- Support OpenAI, Claude, and Gemini API response formats for smooth playback
- Add stream smoothing configuration to system settings (enable, chunk size, delay)
- Create streamlined API for stream smoothing with StreamSmoothingConfig dataclass
- Add admin UI controls for configuring stream smoothing parameters
- Use batch configuration loading to minimize database queries
- Enable typing effect simulation for better user experience in streaming responses
2025-12-19 03:15:19 +08:00
fawney19
daf8b870f0 fix: include Dockerfile.base.local in dependency hash calculation
- Add Dockerfile.base.local to deps hash to detect Docker configuration changes
- Ensures deployment rebuilds when nginx proxy settings are modified
- Prevents stale Docker images from being reused after config changes
2025-12-19 02:38:46 +08:00
fawney19
880fb61c66 fix: disable gzip compression in nginx proxy configuration
- Add gzip off directive to prevent nginx from compressing proxied responses
- Ensures stream integrity for chunked transfer encoding
- Applies to both production and local Dockerfiles
2025-12-19 02:17:07 +08:00
fawney19
7e792dabfc refactor: use background task for client disconnection monitoring
- Replace time-based throttling with background task for disconnect checks
- Remove time.monotonic() and related throttling logic
- Prevent blocking of stream transmission during connection checks
- Properly clean up background task with try/finally block
- Improve throughput and responsiveness of stream processing
2025-12-19 01:59:56 +08:00
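A sketch of the background-task pattern (Starlette's request.is_disconnected() is a real API; the poll interval and wiring are assumptions):

import asyncio

async def stream_with_disconnect_watch(request, upstream):
    """Forward upstream chunks while a background task polls for disconnects,
    so the hot streaming path never blocks on a connection check."""
    disconnected = asyncio.Event()

    async def watch() -> None:
        while not disconnected.is_set():
            if await request.is_disconnected():
                disconnected.set()
                return
            await asyncio.sleep(0.5)

    watcher = asyncio.create_task(watch())
    try:
        async for chunk in upstream:
            if disconnected.is_set():
                break
            yield chunk
    finally:
        watcher.cancel()  # always clean up the background task (try/finally)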
fawney19
cd06169b2f fix: detect OpenAI format stream completion via finish_reason
- Add detection of finish_reason in OpenAI API responses to mark stream completion
- Ensures OpenAI API streams are properly marked as complete even without explicit completion events
- Complements existing completion event detection for other API formats
2025-12-19 01:44:35 +08:00
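The finish_reason check amounts to roughly this (SSE payload handling is an assumption):

import json

def is_openai_stream_complete(sse_data: str) -> bool:
    """True when an OpenAI-format chunk carries a finish_reason or [DONE]."""
    if sse_data.strip() == "[DONE]":
        return True
    try:
        payload = json.loads(sse_data)
    except json.JSONDecodeError:
        return False
    return any(c.get("finish_reason") for c in payload.get("choices", []))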
fawney19
50ffd47546 fix: handle client disconnection after stream completion gracefully
- Check has_completion flag before marking client disconnection as failure
- Allow graceful termination if response already completed when client disconnects
- Change logging level to info for post-completion disconnections
- Prevent false error reporting when client closes connection after receiving full response
2025-12-19 01:36:20 +08:00
fawney19
5f0c1fb347 refactor: remove unused response normalizer module
- Delete unused ResponseNormalizer class and its initialization logic
- Remove response_normalizer and enable_response_normalization parameters from handlers
- Simplify chat adapter base initialization by removing normalizer setup
- Clean up unused imports in handler modules
2025-12-19 01:20:30 +08:00
fawney19
7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00
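For context, a pure ASGI middleware never materializes the response body, which is what keeps streaming intact; a minimal sketch, with hook placement assumed:

class PluginMiddleware:
    def __init__(self, app):
        self.app = app  # the wrapped ASGI app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                # Inspect or annotate headers here; the body streams through
                # untouched, unlike BaseHTTPMiddleware, which buffers it.
                pass
            await send(message)

        await self.app(scope, receive, send_wrapper)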
fawney19
c7b971cfe7 Merge pull request #29 from fawney19/feature/proxy-support
feat: add HTTP/SOCKS5 proxy support for API endpoints
2025-12-18 16:18:25 +08:00
fawney19
293bb592dc fix: enhance proxy configuration with password preservation and UI improvements
- Add 'enabled' field to ProxyConfig for preserving config when disabled
- Mask proxy password in API responses (return '***' instead of actual password)
- Preserve existing password on update when new password not provided
- Add URL encoding for proxy credentials (handle special chars like @, :, /)
- Enhanced URL validation: block SOCKS4, require valid host, forbid embedded auth
- UI improvements: use Switch component, dynamic password placeholder
- Add confirmation dialog for orphaned credentials (URL empty but has username/password)
- Prevent browser password autofill with randomized IDs and CSS text-security
- Unify ProxyConfig type definition in types.ts
2025-12-18 16:14:37 +08:00
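Mask-on-read plus preserve-on-write is the heart of the password handling; a sketch under assumed config shapes:

MASKED_PASSWORD = "***"

def mask_proxy_for_response(config: dict) -> dict:
    """Never return the real proxy password to the admin UI."""
    out = dict(config)
    if out.get("password"):
        out["password"] = MASKED_PASSWORD
    return out

def merge_proxy_update(stored: dict, incoming: dict) -> dict:
    """Keep the stored password when the update omits it or echoes the mask."""
    merged = dict(incoming)
    pw = incoming.get("password")
    if not pw or pw == MASKED_PASSWORD:
        merged["password"] = stored.get("password")
    return merged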
fawney19
3e50c157be feat: add HTTP/SOCKS5 proxy support for API endpoints
- Add proxy field to ProviderEndpoint database model with migration
- Add ProxyConfig Pydantic model for proxy URL validation
- Extend HTTP client pool with create_client_with_proxy method
- Integrate proxy configuration in chat_handler_base.py and cli_handler_base.py
- Update admin API endpoints to support proxy configuration CRUD
- Add proxy configuration UI in frontend EndpointFormDialog

Fixes #28
2025-12-18 14:46:47 +08:00
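A hedged sketch of create_client_with_proxy: httpx takes a proxy URL (the keyword is proxy in recent httpx releases, proxies in older ones, and SOCKS5 requires the httpx[socks] extra); the credential URL-encoding follows commit 293bb592dc above:

from urllib.parse import quote
import httpx

def create_client_with_proxy(proxy: dict | None) -> httpx.AsyncClient:
    if not proxy or not proxy.get("enabled", True) or not proxy.get("url"):
        return httpx.AsyncClient()
    url = proxy["url"]
    user, pw = proxy.get("username"), proxy.get("password")
    if user and "://" in url:
        # URL-encode credentials so characters like @ : / survive.
        scheme, rest = url.split("://", 1)
        cred = quote(user, safe="") + (":" + quote(pw, safe="") if pw else "")
        url = f"{scheme}://{cred}@{rest}"
    return httpx.AsyncClient(proxy=url)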
fawney19
21587449c8 fix: improve error classification and logging system
- Enhance error classifier to properly handle API key failures with fallback support
- Add error reason/code parsing for better AWS and multi-provider compatibility
- Improve error message structure detection for non-standard formats
- Refactor file logging with size-based rotation (100MB) instead of daily
- Optimize production logging by disabling backtrace and diagnose
- Clean up model validation and remove redundant configurations
2025-12-18 10:57:31 +08:00
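The logging bullets map onto loguru options like so (the sink path and retention count are assumptions):

from loguru import logger

logger.add(
    "logs/app.log",
    rotation="100 MB",   # size-based rotation instead of daily files
    retention=10,        # keep a bounded number of rotated files (assumed)
    backtrace=False,     # skip expensive traceback capture in production
    diagnose=False,      # never dump local variable values into logs
)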
fawney19
3d0ab353d3 refactor: migrate Pydantic Config to v2 ConfigDict 2025-12-18 02:20:53 +08:00
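The v1-to-v2 change this commit applies, in miniature (field names are illustrative):

from pydantic import BaseModel, ConfigDict

# Pydantic v1 style (removed):
#     class Config:
#         orm_mode = True

class UserOut(BaseModel):
    model_config = ConfigDict(from_attributes=True)  # v2 replacement for orm_mode
    id: str
    email: str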
fawney19
b2a857c164 refactor: consolidate transaction management and remove legacy modules
- Remove unused context.py module (replaced by request.state)
- Remove provider_cache.py (no longer needed)
- Unify environment loading in config/settings.py instead of __init__.py
- Add deprecation warning for get_async_db() (consolidating on sync Session)
- Enhance database.py documentation with comprehensive transaction strategy
- Simplify audit logging to reuse request-level Session (no separate connections)
- Extract UsageService._build_usage_params() helper to reduce code duplication
- Update model and user cache implementations with refined transaction handling
- Remove unnecessary sessionmaker from pipeline
- Clean up audit service exception handling
2025-12-18 01:59:40 +08:00
fawney19
4d1d863916 refactor: improve authentication and user data handling
- Replace user cache queries with direct database queries to ensure data consistency
- Fix token_type parameter in verify_token calls (access token verification)
- Fix role-based permission check using dictionary ranking instead of string comparison
- Fix logout operation to use correct JWT claim name (user_id instead of sub)
- Simplify user authentication flow by removing unnecessary cache layer
- Optimize session initialization in main.py using create_session helper
- Remove unused imports and exception variables
2025-12-18 01:09:22 +08:00
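The role-check fix matters because plain string comparison orders 'admin' < 'user' lexicographically, inverting the hierarchy; ranking via a dictionary looks like this sketch:

ROLE_RANK = {"user": 0, "admin": 1}

def has_permission(user_role: str, required_role: str) -> bool:
    """Compare roles by numeric rank, never by string order."""
    return ROLE_RANK.get(user_role, -1) >= ROLE_RANK.get(required_role, 0)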
fawney19
b579420690 refactor: optimize database session lifecycle and middleware architecture
- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
2025-12-18 00:35:46 +08:00
fawney19
9d5c84f9d3 refactor: add scheduling mode support and optimize system settings UI
- Add fixed_order and cache_affinity scheduling modes to CacheAwareScheduler
- Only apply cache affinity in cache_affinity mode; use fixed order otherwise
- Simplify Dialog components with title/description props
- Remove unnecessary button shadows in SystemSettings
- Optimize import dialog UI structure
- Update ModelAliasesTab shadow styling
- Fix fallback orchestrator type hints
- Add scheduling_mode configuration in system config
2025-12-17 19:15:08 +08:00
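The two scheduling modes reduce to a selection rule like this sketch (the selection details are assumptions):

def select_key(keys: list[str], mode: str, affinity: str | None) -> str:
    """fixed_order: always walk keys in configured priority order.
    cache_affinity: prefer the key that last served this (api key, model) pair."""
    if mode == "cache_affinity" and affinity in keys:
        return affinity
    return keys[0]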
fawney19
53e6a82480 Merge pull request #25 from fawney19/fix/22-dark-mode-settings-bug
fix: fix dark mode on the personal settings page being reset after a refresh
2025-12-17 18:11:53 +08:00
fawney19
bd11ebdbd5 fix: fix dark mode on the personal settings page being reset after a refresh
- Frontend: unify theme-switching logic through the useDarkMode composable
- Backend: accept the 'system' theme value (previously only 'auto' was supported)
- Treat local localStorage as the source of truth for the theme, so a stale server value cannot override it on refresh

Fixes #22
2025-12-17 18:02:19 +08:00
fawney19
1dac4cb156 refactor: optimize provider query and stats aggregation logic 2025-12-17 16:41:10 +08:00
fawney19
50abb55c94 fix(models): clear form state when loading model data for edit
Reset model selection, search query, and expanded provider state
when switching to edit mode to prevent stale UI state carrying over
from previous operations. Also ensure tieredPricing is properly set
or reset based on model data.
2025-12-16 18:42:58 +08:00
fawney19
73d3c9d3e4 ui(models): display model ID in global model form dialog
Show model ID below model name in the dropdown list for better clarity
when selecting models, with appropriate text styling for selected state.
2025-12-16 18:36:23 +08:00
fawney19
d24c3885ab feat(admin): add config and user data import/export functionality
Add comprehensive import/export endpoints for:
- Provider and model configuration (with key decryption for export)
- User data and API keys (preserving encrypted data)

Includes merge modes (skip/overwrite/error) for conflict handling,
10MB size limit for imports, and automatic cache invalidation.

Also fix optional field in GlobalModelResponse tiered_pricing.
2025-12-16 18:33:14 +08:00
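The merge modes behave roughly as follows (a sketch; the real import also tallies per-entity stats):

from typing import Literal

MergeMode = Literal["skip", "overwrite", "error"]

def resolve_conflict(existing: object | None, mode: MergeMode) -> str:
    """Decide what to do when an imported record already exists."""
    if existing is None:
        return "create"
    if mode == "skip":
        return "skipped"
    if mode == "overwrite":
        return "update"
    raise ValueError("import conflict: record already exists")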
104 changed files with 7665 additions and 2541 deletions

View File

@@ -15,6 +15,8 @@ env:
REGISTRY: ghcr.io
BASE_IMAGE_NAME: fawney19/aether-base
APP_IMAGE_NAME: fawney19/aether
# Files that affect base image - used for hash calculation
BASE_FILES: "Dockerfile.base pyproject.toml frontend/package.json frontend/package-lock.json"
jobs:
check-base-changes:
@@ -23,8 +25,13 @@ jobs:
base_changed: ${{ steps.check.outputs.base_changed }}
steps:
- uses: actions/checkout@v4
  with:
    fetch-depth: 2
- name: Log in to Container Registry
  uses: docker/login-action@v3
  with:
    registry: ${{ env.REGISTRY }}
    username: ${{ github.actor }}
    password: ${{ secrets.GITHUB_TOKEN }}
- name: Check if base image needs rebuild
id: check
@@ -34,10 +41,26 @@ jobs:
exit 0
fi
# Check if base-related files changed
if git diff --name-only HEAD~1 HEAD | grep -qE '^(Dockerfile\.base|pyproject\.toml|frontend/package.*\.json)$'; then
# Calculate current hash of base-related files
CURRENT_HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
echo "Current base files hash: $CURRENT_HASH"
# Try to get hash label from remote image config
# Pull the image config and extract labels
REMOTE_HASH=""
if docker pull ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest 2>/dev/null; then
REMOTE_HASH=$(docker inspect ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest --format '{{ index .Config.Labels "org.opencontainers.image.base.hash" }}' 2>/dev/null) || true
fi
if [ -z "$REMOTE_HASH" ] || [ "$REMOTE_HASH" == "<no value>" ]; then
# No remote image or no hash label, need to rebuild
echo "No remote base image or hash label found, need rebuild"
echo "base_changed=true" >> $GITHUB_OUTPUT
elif [ "$CURRENT_HASH" != "$REMOTE_HASH" ]; then
echo "Hash mismatch: remote=$REMOTE_HASH, current=$CURRENT_HASH"
echo "base_changed=true" >> $GITHUB_OUTPUT
else
echo "Hash matches, no rebuild needed"
echo "base_changed=false" >> $GITHUB_OUTPUT
fi
@@ -61,6 +84,12 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Calculate base files hash
id: hash
run: |
HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
echo "hash=$HASH" >> $GITHUB_OUTPUT
- name: Extract metadata for base image
id: meta
uses: docker/metadata-action@v5
@@ -69,6 +98,8 @@ jobs:
tags: |
type=raw,value=latest
type=sha,prefix=
labels: |
org.opencontainers.image.base.hash=${{ steps.hash.outputs.hash }}
- name: Build and push base image
uses: docker/build-push-action@v5
@@ -117,7 +148,7 @@ jobs:
- name: Update Dockerfile.app to use registry base image
run: |
sed -i "s|FROM aether-base:latest|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest|g" Dockerfile.app
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
- name: Build and push app image
uses: docker/build-push-action@v5

View File

@@ -1,16 +1,134 @@
# App image: built on the base image, copies code only (builds in seconds)
# Runtime image: extracts artifacts from base into a slim runtime
# Build command: docker build -f Dockerfile.app -t aether-app:latest .
FROM aether-base:latest
# For GitHub Actions CI (official package sources)
FROM aether-base:latest AS builder
WORKDIR /app
# Copy frontend sources and build
COPY frontend/ ./frontend/
RUN cd frontend && npm run build
# ==================== Runtime image ====================
FROM python:3.12-slim
WORKDIR /app
# Runtime dependencies (no gcc/nodejs/npm)
RUN apt-get update && apt-get install -y \
nginx \
supervisor \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy Python packages from the base image
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
# Copy only the Python executables we need
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/
# Copy frontend build artifacts from the builder stage
COPY --from=builder /app/frontend/dist /usr/share/nginx/html
# Copy backend code
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/
# Build the frontend (reuses node_modules installed in the base image)
COPY frontend/ /tmp/frontend/
RUN cd /tmp/frontend && npm run build && \
cp -r dist/* /usr/share/nginx/html/ && \
rm -rf /tmp/frontend
# Nginx configuration template
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' gzip off;' \
' add_header X-Accel-Buffering no;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor configuration
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

135
Dockerfile.app.local Normal file
View File

@@ -0,0 +1,135 @@
# Runtime image: extracts artifacts from base into a slim runtime (China mirror variant)
# Build command: docker build -f Dockerfile.app.local -t aether-app:latest .
# For local / in-China server deployments
FROM aether-base:latest AS builder
WORKDIR /app
# Copy frontend sources and build
COPY frontend/ ./frontend/
RUN cd frontend && npm run build
# ==================== Runtime image ====================
FROM python:3.12-slim
WORKDIR /app
# Runtime dependencies (using the Tsinghua mirror)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
apt-get update && apt-get install -y \
nginx \
supervisor \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy Python packages from the base image
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
# Copy only the Python executables we need
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/
# Copy frontend build artifacts from the builder stage
COPY --from=builder /app/frontend/dist /usr/share/nginx/html
# Copy backend code
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/
# Nginx configuration template
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' gzip off;' \
' add_header X-Accel-Buffering no;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor configuration
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -1,122 +1,25 @@
# Base image: contains all dependencies; rebuilt only when they change
# Build image: compiler toolchain + pre-installed dependencies
# For GitHub Actions CI builds (no China mirror sources)
# Build command: docker build -f Dockerfile.base -t aether-base:latest .
# Rebuild only when pyproject.toml or frontend/package*.json change
FROM python:3.12-slim
WORKDIR /app
# System dependencies
# Build tools
RUN apt-get update && apt-get install -y \
nginx \
supervisor \
libpq-dev \
gcc \
curl \
gettext-base \
nodejs \
npm \
&& rm -rf /var/lib/apt/lists/*
# Python dependencies (installed system-wide, not in -e mode)
# Python dependencies
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
pip install --no-cache-dir .
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
pip cache purge
# Frontend dependencies
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm ci
# Nginx configuration template
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor configuration
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html
WORKDIR /app
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# Frontend dependencies (install only, no build)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm ci

View File

@@ -1,18 +1,15 @@
# Base image: contains all dependencies; rebuilt only when they change
# Build command: docker build -f Dockerfile.base -t aether-base:latest .
# Build image: compiler toolchain + pre-installed dependencies (China mirror variant)
# Build command: docker build -f Dockerfile.base.local -t aether-base:latest .
# Rebuild only when pyproject.toml or frontend/package*.json change
FROM python:3.12-slim
WORKDIR /app
# System dependencies
# Build tools (using the Tsinghua mirror)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
apt-get update && apt-get install -y \
nginx \
supervisor \
libpq-dev \
gcc \
curl \
gettext-base \
nodejs \
npm \
&& rm -rf /var/lib/apt/lists/*
@@ -20,107 +17,12 @@ RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.li
# pip mirror
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
# Python dependencies (installed system-wide, not in -e mode)
# Python dependencies
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir .
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
pip cache purge
# Frontend dependencies
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm config set registry https://registry.npmmirror.com && npm ci
# Nginx configuration template
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor configuration
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html
WORKDIR /app
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# Frontend dependencies (install only, no build; uses the npmmirror registry)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm config set registry https://registry.npmmirror.com && npm ci

View File

@@ -20,10 +20,10 @@ depends_on = None
def upgrade() -> None:
# Create ENUM types
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
# Create ENUM types (with IF NOT EXISTS for idempotency)
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
op.execute(
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
)
# ==================== users ====================
@@ -35,7 +35,7 @@ def upgrade() -> None:
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column(
"role",
sa.Enum("admin", "user", name="userrole", create_type=False),
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
nullable=False,
server_default="user",
),
@@ -67,7 +67,7 @@ def upgrade() -> None:
sa.Column("website", sa.String(500), nullable=True),
sa.Column(
"billing_type",
sa.Enum(
postgresql.ENUM(
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
),
nullable=False,

View File

@@ -0,0 +1,57 @@
"""add proxy field to provider_endpoints
Revision ID: f30f9936f6a2
Revises: 1cc6942cf06f
Create Date: 2025-12-18 06:31:58.451112+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'f30f9936f6a2'
down_revision = '1cc6942cf06f'
branch_labels = None
depends_on = None
def column_exists(table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
bind = op.get_bind()
inspector = inspect(bind)
columns = [col['name'] for col in inspector.get_columns(table_name)]
return column_name in columns
def get_column_type(table_name: str, column_name: str) -> str:
"""获取列的类型"""
bind = op.get_bind()
inspector = inspect(bind)
for col in inspector.get_columns(table_name):
if col['name'] == column_name:
return str(col['type']).upper()
return ''
def upgrade() -> None:
"""添加 proxy 字段到 provider_endpoints 表"""
if not column_exists('provider_endpoints', 'proxy'):
# Column does not exist; add it directly as JSONB
op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
else:
# Column already exists; check whether the type needs converting
col_type = get_column_type('provider_endpoints', 'proxy')
if 'JSONB' not in col_type:
# If it is plain JSON, convert it to JSONB
op.execute(
'ALTER TABLE provider_endpoints '
'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
)
def downgrade() -> None:
"""移除 proxy 字段"""
if column_exists('provider_endpoints', 'proxy'):
op.drop_column('provider_endpoints', 'proxy')

View File

@@ -21,9 +21,9 @@ HASH_FILE=".deps-hash"
CODE_HASH_FILE=".code-hash"
MIGRATION_HASH_FILE=".migration-hash"
# Compute the hash of dependency files
# Compute the hash of dependency files (including Dockerfile.base.local)
calc_deps_hash() {
cat pyproject.toml frontend/package.json frontend/package-lock.json 2>/dev/null | md5sum | cut -d' ' -f1
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
}
# Compute the hash of code files
@@ -88,7 +88,7 @@ build_base() {
# Build the app image
build_app() {
echo ">>> Building app image (code only)..."
docker build -f Dockerfile.app -t aether-app:latest .
docker build -f Dockerfile.app.local -t aether-app:latest .
save_code_hash
}
@@ -162,25 +162,32 @@ git pull
# Track whether a restart is needed
NEED_RESTART=false
BASE_REBUILT=false
# Check whether the base image exists and whether dependencies changed
if ! docker image inspect aether-base:latest >/dev/null 2>&1; then
echo ">>> Base image not found, building..."
build_base
BASE_REBUILT=true
NEED_RESTART=true
elif check_deps_changed; then
echo ">>> Dependencies changed, rebuilding base image..."
build_base
BASE_REBUILT=true
NEED_RESTART=true
else
echo ">>> Dependencies unchanged."
fi
# Check whether the code changed
# Check whether the code changed, or base was rebuilt (the app image depends on base)
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
echo ">>> App image not found, building..."
build_app
NEED_RESTART=true
elif [ "$BASE_REBUILT" = true ]; then
echo ">>> Base image rebuilt, rebuilding app image..."
build_app
NEED_RESTART=true
elif check_code_changed; then
echo ">>> Code changed, rebuilding app image..."
build_app

View File

@@ -41,7 +41,7 @@ services:
app:
build:
context: .
dockerfile: Dockerfile.app
dockerfile: Dockerfile.app.local
image: aether-app:latest
container_name: aether-app
environment:

View File

@@ -1,5 +1,179 @@
import apiClient from './client'
// Config export data structure
export interface ConfigExportData {
version: string
exported_at: string
global_models: GlobalModelExport[]
providers: ProviderExport[]
}
// Users export data structure
export interface UsersExportData {
version: string
exported_at: string
users: UserExport[]
}
export interface UserExport {
email: string
username: string
password_hash: string
role: string
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_models?: string[] | null
model_capability_settings?: any
quota_usd?: number | null
used_usd?: number
total_usd?: number
is_active: boolean
api_keys: UserApiKeyExport[]
}
export interface UserApiKeyExport {
key_hash: string
key_encrypted?: string | null
name?: string | null
is_standalone: boolean
balance_used_usd?: number
current_balance_usd?: number | null
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
auto_delete_on_expiry?: boolean
total_requests?: number
total_cost_usd?: number
}
export interface GlobalModelExport {
name: string
display_name: string
default_price_per_request?: number | null
default_tiered_pricing: any
supported_capabilities?: string[] | null
config?: any
is_active: boolean
}
export interface ProviderExport {
name: string
display_name: string
description?: string | null
website?: string | null
billing_type?: string | null
monthly_quota_usd?: number | null
quota_reset_day?: number
rpm_limit?: number | null
provider_priority?: number
is_active: boolean
rate_limit?: number | null
concurrent_limit?: number | null
config?: any
endpoints: EndpointExport[]
models: ModelExport[]
}
export interface EndpointExport {
api_format: string
base_url: string
headers?: any
timeout?: number
max_retries?: number
max_concurrent?: number | null
rate_limit?: number | null
is_active: boolean
custom_path?: string | null
config?: any
keys: KeyExport[]
}
export interface KeyExport {
api_key: string
name?: string | null
note?: string | null
rate_multiplier?: number
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null
rate_limit?: number | null
daily_limit?: number | null
monthly_limit?: number | null
allowed_models?: string[] | null
capabilities?: any
is_active: boolean
}
export interface ModelExport {
global_model_name: string | null
provider_model_name: string
provider_model_aliases?: any
price_per_request?: number | null
tiered_pricing?: any
supports_vision?: boolean | null
supports_function_calling?: boolean | null
supports_streaming?: boolean | null
supports_extended_thinking?: boolean | null
supports_image_generation?: boolean | null
is_active: boolean
config?: any
}
// Provider models query response
export interface ProviderModelsQueryResponse {
success: boolean
data: {
models: Array<{
id: string
object?: string
created?: number
owned_by?: string
display_name?: string
api_format?: string
}>
error?: string
}
provider: {
id: string
name: string
display_name: string
}
}
export interface ConfigImportRequest extends ConfigExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportRequest extends UsersExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportResponse {
message: string
stats: {
users: { created: number; updated: number; skipped: number }
api_keys: { created: number; skipped: number }
errors: string[]
}
}
export interface ConfigImportResponse {
message: string
stats: {
global_models: { created: number; updated: number; skipped: number }
providers: { created: number; updated: number; skipped: number }
endpoints: { created: number; updated: number; skipped: number }
keys: { created: number; updated: number; skipped: number }
models: { created: number; updated: number; skipped: number }
errors: string[]
}
}
// API key management type definitions
export interface AdminApiKey {
id: string // UUID
@@ -173,5 +347,44 @@ export const adminApi = {
'/api/admin/system/api-formats'
)
return response.data
},
// Export config
async exportConfig(): Promise<ConfigExportData> {
const response = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
return response.data
},
// Import config
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
const response = await apiClient.post<ConfigImportResponse>(
'/api/admin/system/config/import',
data
)
return response.data
},
// Export user data
async exportUsers(): Promise<UsersExportData> {
const response = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
return response.data
},
// Import user data
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
const response = await apiClient.post<UsersImportResponse>(
'/api/admin/system/users/import',
data
)
return response.data
},
// Query a Provider's available models (fetched from the upstream API)
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
const response = await apiClient.post<ProviderModelsQueryResponse>(
'/api/admin/provider-query/models',
{ provider_id: providerId, api_key_id: apiKeyId }
)
return response.data
}
}

View File

@@ -66,6 +66,7 @@ export interface UserAffinity {
key_name: string | null
key_prefix: string | null // Provider key, masked for display (first 4...last 4)
rate_multiplier: number
global_model_id: string | null // the original global_model_id (used for deletion)
model_name: string | null // model name (e.g. claude-haiku-4-5-20250514)
model_display_name: string | null // model display name (e.g. Claude Haiku 4.5)
api_format: string | null // API format (claude/openai)
@@ -119,6 +120,18 @@ export const cacheApi = {
await api.delete(`/api/admin/monitoring/cache/users/${userIdentifier}`)
},
/**
* Clear a single cache-affinity entry
*
* @param affinityKey API Key ID
* @param endpointId Endpoint ID
* @param modelId GlobalModel ID
* @param apiFormat API format (claude/openai)
*/
async clearSingleAffinity(affinityKey: string, endpointId: string, modelId: string, apiFormat: string): Promise<void> {
await api.delete(`/api/admin/monitoring/cache/affinity/${affinityKey}/${endpointId}/${modelId}/${apiFormat}`)
},
/**
* Clear all caches
*/

View File

@@ -1,5 +1,5 @@
import client from '../client'
import type { ProviderEndpoint } from './types'
import type { ProviderEndpoint, ProxyConfig } from './types'
/**
* Fetch all Endpoints for the given Provider
@@ -38,6 +38,7 @@ export async function createEndpoint(
rate_limit?: number
is_active?: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
}
): Promise<ProviderEndpoint> {
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
@@ -63,6 +64,7 @@ export async function updateEndpoint(
rate_limit: number
is_active: boolean
config: Record<string, any>
proxy: ProxyConfig | null
}>
): Promise<ProviderEndpoint> {
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)

View File

@@ -1,3 +1,35 @@
// API format constants
export const API_FORMATS = {
CLAUDE: 'CLAUDE',
CLAUDE_CLI: 'CLAUDE_CLI',
OPENAI: 'OPENAI',
OPENAI_CLI: 'OPENAI_CLI',
GEMINI: 'GEMINI',
GEMINI_CLI: 'GEMINI_CLI',
} as const
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
// API format display-name map (grouped by brand; API before CLI)
export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.CLAUDE]: 'Claude',
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
[API_FORMATS.OPENAI]: 'OpenAI',
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
[API_FORMATS.GEMINI]: 'Gemini',
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
/**
* Proxy configuration type
*/
export interface ProxyConfig {
url: string
username?: string
password?: string
enabled?: boolean // whether the proxy is enabled (when false, the config is kept but not used)
}
export interface ProviderEndpoint {
id: string
provider_id: string
@@ -19,6 +51,7 @@ export interface ProviderEndpoint {
last_failure_at?: string
is_active: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
total_keys: number
active_keys: number
created_at: string
@@ -214,6 +247,7 @@ export interface ConcurrencyStatus {
export interface ProviderModelAlias {
name: string
priority: number // priority (smaller number = higher priority)
api_formats?: string[] // scope (applicable API formats); empty means it applies to all formats
}
export interface Model {

View File

@@ -34,11 +34,10 @@ const buttonClass = computed(() => {
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
const variantClasses = {
default:
'bg-primary text-white shadow-[0_20px_35px_rgba(204,120,92,0.35)] hover:bg-primary/90 hover:shadow-[0_25px_45px_rgba(204,120,92,0.45)]',
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85 shadow-sm',
default: 'bg-primary text-white hover:bg-primary/90',
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85',
outline:
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 shadow-sm backdrop-blur transition-all',
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 backdrop-blur transition-all',
secondary:
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
ghost: 'hover:bg-accent hover:text-accent-foreground',

View File

@@ -2,7 +2,7 @@
<Teleport to="body">
<div
v-if="isOpen"
class="fixed inset-0 overflow-y-auto"
class="fixed inset-0 overflow-y-auto pointer-events-none"
:style="{ zIndex: containerZIndex }"
>
<!-- Backdrop overlay -->
@@ -16,13 +16,13 @@
>
<div
v-if="isOpen"
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity"
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
:style="{ zIndex: backdropZIndex }"
@click="handleClose"
/>
</Transition>
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
<!-- Dialog content -->
<Transition
enter-active-class="duration-300 ease-out"
@@ -34,7 +34,7 @@
>
<div
v-if="isOpen"
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border"
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border pointer-events-auto"
:style="{ zIndex: contentZIndex }"
:class="maxWidthClass"
@click.stop

View File

@@ -45,7 +45,7 @@ const props = withDefaults(defineProps<Props>(), {
const contentClass = computed(() =>
cn(
'z-[100] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
'z-[200] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
props.class

View File

@@ -132,7 +132,7 @@
type="number"
min="1"
max="10000"
placeholder="100"
placeholder="留空不限制"
class="h-10"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
/>
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -389,7 +389,7 @@ function resetForm() {
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -408,7 +408,7 @@ function loadKeyData() {
initial_balance_usd: props.apiKey.initial_balance_usd,
expire_days: props.apiKey.expire_days,
never_expire: props.apiKey.never_expire,
rate_limit: props.apiKey.rate_limit || 100,
rate_limit: props.apiKey.rate_limit,
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
allowed_providers: props.apiKey.allowed_providers || [],
allowed_api_formats: props.apiKey.allowed_api_formats || [],

View File

@@ -68,13 +68,19 @@
<div
v-for="model in group.models"
:key="model.modelId"
class="flex items-center gap-2 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'bg-primary text-primary-foreground'
: 'hover:bg-muted'"
@click="selectModel(model)"
>
<span class="truncate">{{ model.modelName }}</span>
<span class="truncate font-medium">{{ model.modelName }}</span>
<span
class="truncate text-[10px]"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'text-primary-foreground/70'
: 'text-muted-foreground'"
>{{ model.modelId }}</span>
</div>
</div>
</div>
@@ -390,15 +396,13 @@ interface ProviderGroup {
const groupedModels = computed(() => {
let models = allModels.value.filter(m => !m.deprecated)
// Search (supports space-separated multi-keyword AND search)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
models = models.filter(model =>
model.providerId.toLowerCase().includes(query) ||
model.providerName.toLowerCase().includes(query) ||
model.modelId.toLowerCase().includes(query) ||
model.modelName.toLowerCase().includes(query) ||
model.family?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
models = models.filter(model => {
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// Group by provider
@@ -415,14 +419,16 @@ const groupedModels = computed(() => {
}
// Convert to an array and sort
let result = Array.from(groups.values())
const result = Array.from(groups.values())
// If there is a search term, rank providers whose name/ID matches first
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result.sort((a, b) => {
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
const aProviderMatch = keywords.some(k => aText.includes(k))
const bProviderMatch = keywords.some(k => bText.includes(k))
if (aProviderMatch && !bProviderMatch) return -1
if (!aProviderMatch && bProviderMatch) return 1
return a.providerName.localeCompare(b.providerName)
@@ -598,6 +604,11 @@ function resetForm() {
// Load model data (edit mode)
function loadModelData() {
if (!props.model) return
// First reset any state left over from create mode
selectedModel.value = null
searchQuery.value = ''
expandedProvider.value = null
form.value = {
name: props.model.name,
display_name: props.model.display_name,
@@ -606,9 +617,10 @@ function loadModelData() {
config: props.model.config ? { ...props.model.config } : { streaming: true },
is_active: props.model.is_active,
}
if (props.model.default_tiered_pricing) {
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
}
// Make sure tieredPricing is also set or reset correctly
tieredPricing.value = props.model.default_tiered_pricing
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
: null
}
// Use useFormDialog to centralize the dialog logic

View File

@@ -9,7 +9,7 @@
>
<form
class="space-y-6"
@submit.prevent="handleSubmit"
@submit.prevent="handleSubmit()"
>
<!-- API configuration -->
<div class="space-y-4">
@@ -132,6 +132,79 @@
</div>
</div>
</div>
<!-- Proxy configuration -->
<div class="space-y-4">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch v-model="proxyEnabled" />
<span class="text-sm text-muted-foreground">启用代理</span>
</div>
</div>
<div
v-if="proxyEnabled"
class="space-y-4 rounded-lg border p-4"
>
<div class="space-y-2">
<Label for="proxy_url">代理 URL *</Label>
<Input
id="proxy_url"
v-model="form.proxy_url"
placeholder="http://host:port 或 socks5://host:port"
required
:class="proxyUrlError ? 'border-red-500' : ''"
/>
<p
v-if="proxyUrlError"
class="text-xs text-red-500"
>
{{ proxyUrlError }}
</p>
<p
v-else
class="text-xs text-muted-foreground"
>
支持 HTTP、HTTPS、SOCKS5 代理
</p>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="proxy_user">用户名可选</Label>
<Input
:id="`proxy_user_${formId}`"
:name="`proxy_user_${formId}`"
v-model="form.proxy_username"
placeholder="代理认证用户名"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-2">
<Label :for="`proxy_pass_${formId}`">密码(可选)</Label>
<Input
:id="`proxy_pass_${formId}`"
:name="`proxy_pass_${formId}`"
v-model="form.proxy_password"
type="text"
:placeholder="passwordPlaceholder"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
/>
</div>
</div>
</div>
</div>
</form>
<template #footer>
@@ -145,12 +218,24 @@
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
@click="handleSubmit"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
</Button>
</template>
</Dialog>
<!-- Confirm-clear-credentials dialog -->
<AlertDialog
v-model="showClearCredentialsDialog"
title="清空代理凭据"
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
type="warning"
confirm-text="清空并保存"
cancel-text="返回编辑"
@confirm="confirmClearCredentials"
@cancel="showClearCredentialsDialog = false"
/>
</template>
<script setup lang="ts">
@@ -165,7 +250,9 @@ import {
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue'
import { Link, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
@@ -194,6 +281,11 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
// 生成随机 ID,防止浏览器自动填充
const formId = Math.random().toString(36).substring(2, 10)
// 内部状态
const internalOpen = computed(() => props.modelValue)
@@ -207,7 +299,11 @@ const form = ref({
max_retries: 3,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true
is_active: true,
// 代理配置
proxy_url: '',
proxy_username: '',
proxy_password: '',
})
// API 格式列表
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
})
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
const hasExistingPassword = computed(() => {
if (!props.endpoint?.proxy) return false
const proxy = props.endpoint.proxy as { password?: string }
return proxy?.password === MASKED_PASSWORD
})
// 密码输入框的 placeholder
const passwordPlaceholder = computed(() => {
if (hasExistingPassword.value) {
return '已保存密码,留空保持不变'
}
return '代理认证密码'
})
// 代理 URL 验证
const proxyUrlError = computed(() => {
// 只有启用代理且填写了 URL 时才验证
if (!proxyEnabled.value || !form.value.proxy_url) {
return ''
}
const url = form.value.proxy_url.trim()
// 检查禁止的特殊字符
if (/[\n\r]/.test(url)) {
return '代理 URL 包含非法字符'
}
// 验证协议(不支持 SOCKS4)
if (!/^(http|https|socks5):\/\//i.test(url)) {
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
}
try {
const parsed = new URL(url)
if (!parsed.host) {
return '代理 URL 必须包含有效的 host'
}
// 禁止 URL 中内嵌认证信息
if (parsed.username || parsed.password) {
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
}
} catch {
return '代理 URL 格式无效'
}
return ''
})
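A quick sanity check of the same validation rules outside the component (a sketch; the test cases are illustrative, not from the diff):

    const cases: Array<[string, boolean]> = [
      ['http://127.0.0.1:7890', true],
      ['socks5://proxy.local:1080', true],
      ['socks4://proxy.local:1080', false],   // SOCKS4 is rejected
      ['http://user:pass@host:8080', false],  // inline credentials are rejected
    ]
    for (const [url, expected] of cases) {
      const schemeOk = /^(http|https|socks5):\/\//i.test(url)
      let ok = false
      if (schemeOk) {
        const parsed = new URL(url)
        ok = !!parsed.host && !parsed.username && !parsed.password
      }
      console.assert(ok === expected, `unexpected result for ${url}`)
    }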
// 组件挂载时加载API格式
onMounted(() => {
loadApiFormats()
@@ -252,14 +395,23 @@ function resetForm() {
max_retries: 3,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true
is_active: true,
proxy_url: '',
proxy_username: '',
proxy_password: '',
}
proxyEnabled.value = false
}
// 原始密码占位符(后端返回的脱敏标记)
const MASKED_PASSWORD = '***'
// 加载端点数据(编辑模式)
function loadEndpointData() {
if (!props.endpoint) return
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
form.value = {
api_format: props.endpoint.api_format,
base_url: props.endpoint.base_url,
@@ -268,8 +420,15 @@ function loadEndpointData() {
max_retries: props.endpoint.max_retries,
max_concurrent: props.endpoint.max_concurrent || undefined,
rate_limit: props.endpoint.rate_limit || undefined,
is_active: props.endpoint.is_active
is_active: props.endpoint.is_active,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
}
// 根据 enabled 字段或 url 存在判断是否启用代理
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
}
// 使用 useFormDialog 统一处理对话框逻辑
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
resetForm,
})
// 构建代理配置
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
// - 无 URL 时返回 null
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
if (!form.value.proxy_url) {
// 没填 URL,无代理配置
return null
}
return {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: proxyEnabled.value, // 开关状态决定是否启用
}
}
// 提交表单
const handleSubmit = async () => {
const handleSubmit = async (skipCredentialCheck = false) => {
if (!props.provider && !props.endpoint) return
// 只在开关开启且填写了 URL 时验证
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
showError(proxyUrlError.value, '代理配置错误')
return
}
// 检查:开关开启但没有 URL,却有用户名或密码
const hasOrphanedCredentials = proxyEnabled.value
&& !form.value.proxy_url
&& (form.value.proxy_username || form.value.proxy_password)
if (hasOrphanedCredentials && !skipCredentialCheck) {
// 弹出确认对话框
showClearCredentialsDialog.value = true
return
}
loading.value = true
try {
const proxyConfig = buildProxyConfig()
if (isEditMode.value && props.endpoint) {
// 更新端点
await updateEndpoint(props.endpoint.id, {
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点已更新', '保存成功')
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点创建成功', '成功')
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
loading.value = false
}
}
// 确认清空凭据并继续保存
const confirmClearCredentials = () => {
form.value.proxy_username = ''
form.value.proxy_password = ''
showClearCredentialsDialog.value = false
handleSubmit(true) // 跳过凭据检查,直接提交
}
</script>
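The skipCredentialCheck flag implements a confirm-then-resubmit handshake: the first submit pauses on orphaned credentials, and the dialog's confirm handler re-enters with the check disabled. The control flow in isolation (a sketch with hypothetical names):

    let dialogOpen = false
    const form = { proxyUrl: '', username: 'leftover', password: '' }

    function hasOrphanedCredentials(): boolean {
      return !form.proxyUrl && !!(form.username || form.password)
    }

    async function submit(skipCheck = false): Promise<void> {
      if (!skipCheck && hasOrphanedCredentials()) {
        dialogOpen = true        // wait for the user's decision in the dialog
        return
      }
      // ...persist the endpoint here
    }

    function confirmClear(): void {
      form.username = ''
      form.password = ''
      dialogOpen = false
      void submit(true)          // re-enter, skipping the check this time
    }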

View File

@@ -312,8 +312,41 @@
<template #footer>
<div class="flex items-center justify-between w-full">
<div class="text-xs text-muted-foreground">
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
<div class="flex items-center gap-4">
<div class="text-xs text-muted-foreground">
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
</div>
<div class="flex items-center gap-2 pl-4 border-l border-border">
<span class="text-xs text-muted-foreground">调度:</span>
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'cache_affinity'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="优先使用已缓存的Provider利用Prompt Cache"
@click="schedulingMode = 'cache_affinity'"
>
缓存亲和
</button>
</div>
</div>
</div>
<div class="flex gap-2">
<Button
@@ -410,6 +443,9 @@ const saving = ref(false)
// Key 优先级编辑状态
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
// 调度模式状态
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
// 可用的 API 格式
const availableFormats = computed(() => {
return Object.keys(keysByFormat.value).sort()
@@ -433,11 +469,18 @@ watch(internalOpen, async (open) => {
// 加载当前的优先级模式配置
async function loadCurrentPriorityMode() {
try {
const response = await adminApi.getSystemConfig('provider_priority_mode')
const currentMode = response.value || 'provider'
const [priorityResponse, schedulingResponse] = await Promise.all([
adminApi.getSystemConfig('provider_priority_mode'),
adminApi.getSystemConfig('scheduling_mode')
])
const currentMode = priorityResponse.value || 'provider'
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
} catch {
activeMainTab.value = 'provider'
schedulingMode.value = 'cache_affinity'
}
}
@@ -611,11 +654,19 @@ async function save() {
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
await adminApi.updateSystemConfig(
'provider_priority_mode',
newMode,
'Provider/Key 优先级策略:provider(提供商优先模式)或 global_key(全局 Key 优先模式)'
)
// 保存优先级模式和调度模式
await Promise.all([
adminApi.updateSystemConfig(
'provider_priority_mode',
newMode,
'Provider/Key 优先级策略:provider(提供商优先模式)或 global_key(全局 Key 优先模式)'
),
adminApi.updateSystemConfig(
'scheduling_mode',
schedulingMode.value,
'调度模式:fixed_order(固定顺序模式)或 cache_affinity(缓存亲和模式)'
)
])
const providerUpdates = sortedProviders.value.map((provider, index) =>
updateProvider(provider.id, { provider_priority: index + 1 })
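Both settings are read and written in parallel, and the stored strings are normalized into a closed union before use. A sketch of that normalization (assuming the same two config keys as above):

    type SchedulingMode = 'fixed_order' | 'cache_affinity'

    // Coerce an untrusted stored value into the union, defaulting safely
    function toSchedulingMode(raw: unknown): SchedulingMode {
      return raw === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
    }

    // toSchedulingMode('cache_affinity') === 'cache_affinity'
    // toSchedulingMode(undefined)        === 'cache_affinity'  (default)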

View File

@@ -526,7 +526,14 @@
@edit-model="handleEditModel"
@delete-model="handleDeleteModel"
@batch-assign="handleBatchAssign"
@manage-alias="handleManageAlias"
/>
<!-- 模型名称映射 -->
<ModelAliasesTab
v-if="provider"
:key="`aliases-${provider.id}`"
:provider="provider"
@refresh="handleRelatedDataRefresh"
/>
</div>
</template>
@@ -629,16 +636,6 @@
@update:open="batchAssignDialogOpen = $event"
@changed="handleBatchAssignChanged"
/>
<!-- 模型别名管理对话框 -->
<ModelAliasDialog
v-if="open && provider"
:open="aliasDialogOpen"
:provider-id="provider.id"
:model="aliasEditingModel"
@update:open="aliasDialogOpen = $event"
@saved="handleAliasSaved"
/>
</template>
<script setup lang="ts">
@@ -667,8 +664,8 @@ import {
KeyFormDialog,
KeyAllowedModelsDialog,
ModelsTab,
BatchAssignModelsDialog,
ModelAliasDialog
ModelAliasesTab,
BatchAssignModelsDialog
} from '@/features/providers/components'
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false)
// 别名管理相关状态
const aliasDialogOpen = ref(false)
const aliasEditingModel = ref<Model | null>(null)
// 拖动排序相关状态
const dragState = ref({
isDragging: false,
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value ||
deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value ||
aliasDialogOpen.value
batchAssignDialogOpen.value
)
// 监听 providerId 变化
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
keyAllowedModelsDialogOpen.value = false
deleteKeyConfirmOpen.value = false
batchAssignDialogOpen.value = false
aliasDialogOpen.value = false
// 重置临时数据
endpointToEdit.value = null
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
emit('refresh')
}
// 处理管理映射 - 打开别名对话框
function handleManageAlias(model: Model) {
aliasEditingModel.value = model
aliasDialogOpen.value = true
}
// 处理别名保存完成
async function handleAliasSaved() {
aliasEditingModel.value = null
await loadProvider()
emit('refresh')
}
// 处理模型保存完成
async function handleModelSaved() {
editingModel.value = null

View File

@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'

File diff suppressed because it is too large

View File

@@ -165,15 +165,6 @@
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="管理映射"
@click="openAliasDialog(model)"
>
<Tag class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
@@ -218,7 +209,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast'
@@ -233,7 +224,6 @@ const emit = defineEmits<{
'editModel': [model: Model]
'deleteModel': [model: Model]
'batchAssign': []
'manageAlias': [model: Model]
}>()
const { error: showError, success: showSuccess } = useToast()
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
emit('batchAssign')
}
// 打开别名管理对话框
function openAliasDialog(model: Model) {
emit('manageAlias', model)
}
// 切换模型启用状态
async function toggleModelActive(model: Model) {
if (togglingModelId.value) return

View File

@@ -25,7 +25,7 @@
</h3>
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
<span>{{ detail?.model || '-' }}</span>
<template v-if="detail?.target_model">
<template v-if="detail?.target_model && detail.target_model !== detail.model">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"

View File

@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
const filteredApiKeys = computed(() => {
let result = apiKeys.value
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(key =>
(key.name && key.name.toLowerCase().includes(query)) ||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
(key.username && key.username.toLowerCase().includes(query)) ||
(key.user_email && key.user_email.toLowerCase().includes(query))
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(key => {
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 状态筛选
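This multi-keyword AND search recurs across the admin pages below (models, providers, users). The shared logic as one reusable helper (a sketch, not extracted in the diff itself):

    function matchesAllKeywords(parts: Array<string | null | undefined>, query: string): boolean {
      const keywords = query.toLowerCase().split(/\s+/).filter(k => k.length > 0)
      const text = parts.map(p => p || '').join(' ').toLowerCase()
      return keywords.every(keyword => text.includes(keyword))
    }
    // apiKeys.filter(k => matchesAllKeywords([k.name, k.key_display, k.username, k.user_email], searchQuery))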

View File

@@ -142,32 +142,37 @@ async function resetAffinitySearch() {
await fetchAffinityList()
}
async function clearUserCache(identifier: string, displayName?: string) {
const target = identifier?.trim()
if (!target) {
showError('无法识别标识符')
async function clearSingleAffinity(item: UserAffinity) {
const affinityKey = item.affinity_key?.trim()
const endpointId = item.endpoint_id?.trim()
const modelId = item.global_model_id?.trim()
const apiFormat = item.api_format?.trim()
if (!affinityKey || !endpointId || !modelId || !apiFormat) {
showError('缓存记录信息不完整,无法删除')
return
}
const label = displayName || target
const label = item.user_api_key_name || affinityKey
const modelLabel = item.model_display_name || item.model_name || modelId
const confirmed = await showConfirm({
title: '确认清除',
message: `确定要清除 ${label} 的缓存吗?`,
message: `确定要清除 ${label} 在模型 ${modelLabel} 上的缓存亲和性吗?`,
confirmText: '确认清除',
variant: 'destructive'
})
if (!confirmed) return
clearingRowAffinityKey.value = target
clearingRowAffinityKey.value = affinityKey
try {
await cacheApi.clearUserCache(target)
await cacheApi.clearSingleAffinity(affinityKey, endpointId, modelId, apiFormat)
showSuccess('清除成功')
await fetchCacheStats()
await fetchAffinityList(tableKeyword.value.trim() || undefined)
} catch (error) {
showError('清除失败')
log.error('清除用户缓存失败', error)
log.error('清除单条缓存失败', error)
} finally {
clearingRowAffinityKey.value = null
}
@@ -618,7 +623,7 @@ onBeforeUnmount(() => {
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
:disabled="clearingRowAffinityKey === item.affinity_key"
title="清除缓存"
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
@click="clearSingleAffinity(item)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
@@ -668,7 +673,7 @@ onBeforeUnmount(() => {
variant="ghost"
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive shrink-0"
:disabled="clearingRowAffinityKey === item.affinity_key"
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
@click="clearSingleAffinity(item)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
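clearSingleAffinity on the client maps onto the new four-segment DELETE route added to the monitoring API further below. A minimal sketch of that call (assuming admin auth is attached elsewhere, e.g. by an interceptor inside cacheApi):

    async function clearSingleAffinityRequest(
      affinityKey: string, endpointId: string, modelId: string, apiFormat: string,
    ): Promise<void> {
      const path = [affinityKey, endpointId, modelId, apiFormat]
        .map(encodeURIComponent)
        .join('/')
      const res = await fetch(`/api/admin/monitoring/cache/affinity/${path}`, { method: 'DELETE' })
      if (!res.ok) throw new Error(`clear failed: HTTP ${res.status}`)
    }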

View File

@@ -1002,13 +1002,13 @@ async function batchRemoveSelectedProviders() {
const filteredGlobalModels = computed(() => {
let result = globalModels.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选

View File

@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
const filteredProviders = computed(() => {
let result = [...providers.value]
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value.trim()) {
const query = searchQuery.value.toLowerCase()
result = result.filter(p =>
p.display_name.toLowerCase().includes(query) ||
p.name.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(p => {
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 排序

View File

@@ -7,6 +7,7 @@
<template #actions>
<Button
:disabled="loading"
class="shadow-none hover:shadow-none"
@click="saveSystemConfig"
>
{{ loading ? '保存中...' : '保存所有配置' }}
@@ -15,6 +16,94 @@
</PageHeader>
<div class="mt-6 space-y-6">
<!-- 配置导出/导入 -->
<CardSection
title="配置管理"
description="导出或导入提供商和模型配置,便于备份或迁移"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出当前所有提供商、端点、API Key 和模型配置到 JSON 文件
</p>
<Button
variant="outline"
:disabled="exportLoading"
@click="handleExportConfig"
>
<Download class="w-4 h-4 mr-2" />
{{ exportLoading ? '导出中...' : '导出配置' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入配置支持跳过覆盖或报错三种冲突处理模式
</p>
<div class="flex items-center gap-2">
<input
ref="configFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleConfigFileSelect"
>
<Button
variant="outline"
:disabled="importLoading"
@click="triggerConfigFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importLoading ? '导入中...' : '导入配置' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 用户数据导出/导入 -->
<CardSection
title="用户数据管理"
description="导出或导入用户及其 API Keys 数据(不含管理员)"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出所有普通用户及其 API Keys JSON 文件
</p>
<Button
variant="outline"
:disabled="exportUsersLoading"
@click="handleExportUsers"
>
<Download class="w-4 h-4 mr-2" />
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入用户数据需相同 ENCRYPTION_KEY
</p>
<div class="flex items-center gap-2">
<input
ref="usersFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleUsersFileSelect"
>
<Button
variant="outline"
:disabled="importUsersLoading"
@click="triggerUsersFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 基础配置 -->
<CardSection
title="基础配置"
@@ -96,32 +185,13 @@
</div>
</CardSection>
<!-- API Key 管理配置 -->
<!-- 独立余额 Key 过期管理 -->
<CardSection
title="API Key 管理"
description="API Key 相关配置"
title="独立余额 Key 过期管理"
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div>
<Label
for="api-key-expire"
class="block text-sm font-medium"
>
API密钥过期天数
</Label>
<Input
id="api-key-expire"
v-model.number="systemConfig.api_key_expire_days"
type="number"
placeholder="0"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
0 表示永不过期
</p>
</div>
<div class="flex items-center h-full pt-6">
<div class="flex items-center h-full">
<div class="flex items-center space-x-2">
<Checkbox
id="auto-delete-expired-keys"
@@ -135,7 +205,7 @@
自动删除过期 Key
</Label>
<p class="text-xs text-muted-foreground">
关闭时仅禁用过期 Key
关闭时仅禁用过期 Key,不会物理删除
</p>
</div>
</div>
@@ -359,6 +429,25 @@
避免单次操作过大影响性能
</p>
</div>
<div>
<Label
for="audit-log-retention-days"
class="block text-sm font-medium"
>
审计日志保留天数
</Label>
<Input
id="audit-log-retention-days"
v-model.number="systemConfig.audit_log_retention_days"
type="number"
placeholder="30"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
超过保留天数后删除审计日志记录
</p>
</div>
</div>
<!-- 清理策略说明 -->
@@ -371,15 +460,318 @@
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
<p>5. <strong>审计日志</strong>: 独立清理,记录用户登录、操作等安全事件</p>
</div>
</div>
</CardSection>
</div>
<!-- 导入配置对话框 -->
<Dialog
v-model:open="importDialogOpen"
title="导入配置"
description="选择冲突处理模式并确认导入"
>
<div class="space-y-4">
<div
v-if="importPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
配置预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
<li>
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }}
</li>
<li>
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select
v-model="mergeMode"
v-model:open="mergeModeSelectOpen"
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有配置
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入配置替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="mergeMode === 'skip'">
已存在的配置将被保留,仅导入新配置
</template>
<template v-else-if="mergeMode === 'overwrite'">
已存在的配置将被导入的配置覆盖
</template>
<template v-else>
如果发现任何冲突,导入将中止并回滚
</template>
</p>
</div>
<p class="text-xs text-muted-foreground">
注意:相同的 API Keys 会自动跳过,不会创建重复记录
</p>
</div>
<template #footer>
<Button
variant="outline"
@click="importDialogOpen = false; mergeModeSelectOpen = false"
>
取消
</Button>
<Button
:disabled="importLoading"
@click="confirmImport"
>
{{ importLoading ? '导入中...' : '确认导入' }}
</Button>
</template>
</Dialog>
<!-- 导入结果对话框 -->
<Dialog
v-model:open="importResultDialogOpen"
title="导入完成"
>
<div
v-if="importResult"
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
全局模型
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.global_models.created }},
更新: {{ importResult.stats.global_models.updated }},
跳过: {{ importResult.stats.global_models.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
提供商
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.providers.created }},
更新: {{ importResult.stats.providers.updated }},
跳过: {{ importResult.stats.providers.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
端点
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.endpoints.created }},
更新: {{ importResult.stats.endpoints.updated }},
跳过: {{ importResult.stats.endpoints.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.keys.created }},
跳过: {{ importResult.stats.keys.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg col-span-2">
<p class="font-medium">
模型配置
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.models.created }},
更新: {{ importResult.stats.models.updated }},
跳过: {{ importResult.stats.models.skipped }}
</p>
</div>
</div>
<div
v-if="importResult.stats.errors.length > 0"
class="p-3 bg-destructive/10 rounded-lg"
>
<p class="font-medium text-destructive mb-2">
警告信息
</p>
<ul class="text-sm text-destructive space-y-1">
<li
v-for="(err, index) in importResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<template #footer>
<Button @click="importResultDialogOpen = false">
确定
</Button>
</template>
</Dialog>
<!-- 用户数据导入对话框 -->
<Dialog
v-model:open="importUsersDialogOpen"
title="导入用户数据"
description="选择冲突处理模式并确认导入"
>
<div class="space-y-4">
<div
v-if="importUsersPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
数据预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
<li>
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select
v-model="usersMergeMode"
v-model:open="usersMergeModeSelectOpen"
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有用户
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入数据替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="usersMergeMode === 'skip'">
已存在的用户将被保留,仅导入新用户
</template>
<template v-else-if="usersMergeMode === 'overwrite'">
已存在的用户将被导入的数据覆盖
</template>
<template v-else>
如果发现任何冲突,导入将中止并回滚
</template>
</p>
</div>
<p class="text-xs text-muted-foreground">
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作
</p>
</div>
<template #footer>
<Button
variant="outline"
@click="importUsersDialogOpen = false; usersMergeModeSelectOpen = false"
>
取消
</Button>
<Button
:disabled="importUsersLoading"
@click="confirmImportUsers"
>
{{ importUsersLoading ? '导入中...' : '确认导入' }}
</Button>
</template>
</Dialog>
<!-- 用户数据导入结果对话框 -->
<Dialog
v-model:open="importUsersResultDialogOpen"
title="用户数据导入完成"
>
<div
v-if="importUsersResult"
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
用户
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.users.created }},
更新: {{ importUsersResult.stats.users.updated }},
跳过: {{ importUsersResult.stats.users.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.api_keys.created }},
跳过: {{ importUsersResult.stats.api_keys.skipped }}
</p>
</div>
</div>
<div
v-if="importUsersResult.stats.errors.length > 0"
class="p-3 bg-destructive/10 rounded-lg"
>
<p class="font-medium text-destructive mb-2">
警告信息
</p>
<ul class="text-sm text-destructive space-y-1">
<li
v-for="(err, index) in importUsersResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<template #footer>
<Button @click="importUsersResultDialogOpen = false">
确定
</Button>
</template>
</Dialog>
</PageContainer>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Download, Upload } from 'lucide-vue-next'
import Button from '@/components/ui/button.vue'
import Input from '@/components/ui/input.vue'
import Label from '@/components/ui/label.vue'
@@ -389,9 +781,12 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
import SelectValue from '@/components/ui/select-value.vue'
import SelectContent from '@/components/ui/select-content.vue'
import SelectItem from '@/components/ui/select-item.vue'
import {
Dialog,
} from '@/components/ui'
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
import { useToast } from '@/composables/useToast'
import { adminApi } from '@/api/admin'
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
import { log } from '@/utils/logger'
const { success, error } = useToast()
@@ -403,8 +798,7 @@ interface SystemConfig {
// 用户注册
enable_registration: boolean
require_email_verification: boolean
// API Key 管理
api_key_expire_days: number
// 独立余额 Key 过期管理
auto_delete_expired_keys: boolean
// 日志记录
request_log_level: string
@@ -418,11 +812,34 @@ interface SystemConfig {
header_retention_days: number
log_retention_days: number
cleanup_batch_size: number
audit_log_retention_days: number
}
const loading = ref(false)
const logLevelSelectOpen = ref(false)
// 导出/导入相关
const exportLoading = ref(false)
const importLoading = ref(false)
const importDialogOpen = ref(false)
const importResultDialogOpen = ref(false)
const configFileInput = ref<HTMLInputElement | null>(null)
const importPreview = ref<ConfigExportData | null>(null)
const importResult = ref<ConfigImportResponse | null>(null)
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const mergeModeSelectOpen = ref(false)
// 用户数据导出/导入相关
const exportUsersLoading = ref(false)
const importUsersLoading = ref(false)
const importUsersDialogOpen = ref(false)
const importUsersResultDialogOpen = ref(false)
const usersFileInput = ref<HTMLInputElement | null>(null)
const importUsersPreview = ref<UsersExportData | null>(null)
const importUsersResult = ref<UsersImportResponse | null>(null)
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const usersMergeModeSelectOpen = ref(false)
const systemConfig = ref<SystemConfig>({
// 基础配置
default_user_quota_usd: 10.0,
@@ -430,8 +847,7 @@ const systemConfig = ref<SystemConfig>({
// 用户注册
enable_registration: false,
require_email_verification: false,
// API Key 管理
api_key_expire_days: 0,
// 独立余额 Key 过期管理
auto_delete_expired_keys: false,
// 日志记录
request_log_level: 'basic',
@@ -445,6 +861,7 @@ const systemConfig = ref<SystemConfig>({
header_retention_days: 90,
log_retention_days: 365,
cleanup_batch_size: 1000,
audit_log_retention_days: 30,
})
// 计算属性:KB 和字节之间的转换
@@ -486,8 +903,7 @@ async function loadSystemConfig() {
// 用户注册
'enable_registration',
'require_email_verification',
// API Key 管理
'api_key_expire_days',
// 独立余额 Key 过期管理
'auto_delete_expired_keys',
// 日志记录
'request_log_level',
@@ -501,6 +917,7 @@ async function loadSystemConfig() {
'header_retention_days',
'log_retention_days',
'cleanup_batch_size',
'audit_log_retention_days',
]
for (const key of configs) {
@@ -545,12 +962,7 @@ async function saveSystemConfig() {
value: systemConfig.value.require_email_verification,
description: '是否需要邮箱验证'
},
// API Key 管理
{
key: 'api_key_expire_days',
value: systemConfig.value.api_key_expire_days,
description: 'API密钥过期天数'
},
// 独立余额 Key 过期管理
{
key: 'auto_delete_expired_keys',
value: systemConfig.value.auto_delete_expired_keys,
@@ -608,6 +1020,11 @@ async function saveSystemConfig() {
value: systemConfig.value.cleanup_batch_size,
description: '每批次清理的记录数'
},
{
key: 'audit_log_retention_days',
value: systemConfig.value.audit_log_retention_days,
description: '审计日志保留天数'
},
]
const promises = configItems.map(item =>
@@ -623,4 +1040,185 @@ async function saveSystemConfig() {
loading.value = false
}
}
// 导出配置
async function handleExportConfig() {
exportLoading.value = true
try {
const data = await adminApi.exportConfig()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-config-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('配置已导出')
} catch (err) {
error('导出配置失败')
log.error('导出配置失败:', err)
} finally {
exportLoading.value = false
}
}
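The export handler uses the standard blob-download idiom. Distilled into a helper (a sketch; the behavior matches the code above):

    function downloadJson(data: unknown, filename: string): void {
      const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
      const url = URL.createObjectURL(blob)
      const a = document.createElement('a')
      a.href = url
      a.download = filename
      document.body.appendChild(a)   // Firefox requires the anchor to be in the DOM
      a.click()
      document.body.removeChild(a)
      URL.revokeObjectURL(url)       // release the object URL after the click fires
    }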
// 触发文件选择
function triggerConfigFileSelect() {
configFileInput.value?.click()
}
// 文件大小限制 (10MB)
const MAX_FILE_SIZE = 10 * 1024 * 1024
// 处理文件选择
function handleConfigFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as ConfigExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importPreview.value = data
mergeMode.value = 'skip'
importDialogOpen.value = true
} catch (err) {
error('解析配置文件失败,请确保是有效的 JSON 文件')
log.error('解析配置文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
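The same guarded read (size cap, JSON parse, version gate) can be expressed as a promise, which makes the error paths explicit. A sketch under the same 10MB / version '1.0' assumptions:

    function readVersionedJson<T extends { version?: string }>(
      file: File,
      maxBytes = 10 * 1024 * 1024,
    ): Promise<T> {
      return new Promise((resolve, reject) => {
        if (file.size > maxBytes) return reject(new Error('file exceeds 10MB'))
        const reader = new FileReader()
        reader.onerror = () => reject(reader.error)
        reader.onload = () => {
          try {
            const data = JSON.parse(reader.result as string) as T
            if (data.version !== '1.0') return reject(new Error(`unsupported version: ${data.version}`))
            resolve(data)
          } catch (err) {
            reject(err)
          }
        }
        reader.readAsText(file)
      })
    }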
// 确认导入
async function confirmImport() {
if (!importPreview.value) return
importLoading.value = true
try {
const result = await adminApi.importConfig({
...importPreview.value,
merge_mode: mergeMode.value
})
importResult.value = result
importDialogOpen.value = false
mergeModeSelectOpen.value = false
importResultDialogOpen.value = true
success('配置导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入配置失败')
log.error('导入配置失败:', err)
} finally {
importLoading.value = false
}
}
// 导出用户数据
async function handleExportUsers() {
exportUsersLoading.value = true
try {
const data = await adminApi.exportUsers()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-users-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('用户数据已导出')
} catch (err) {
error('导出用户数据失败')
log.error('导出用户数据失败:', err)
} finally {
exportUsersLoading.value = false
}
}
// 触发用户数据文件选择
function triggerUsersFileSelect() {
usersFileInput.value?.click()
}
// 处理用户数据文件选择
function handleUsersFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as UsersExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importUsersPreview.value = data
usersMergeMode.value = 'skip'
importUsersDialogOpen.value = true
} catch (err) {
error('解析用户数据文件失败,请确保是有效的 JSON 文件')
log.error('解析用户数据文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
// 确认导入用户数据
async function confirmImportUsers() {
if (!importUsersPreview.value) return
importUsersLoading.value = true
try {
const result = await adminApi.importUsers({
...importUsersPreview.value,
merge_mode: usersMergeMode.value
})
importUsersResult.value = result
importUsersDialogOpen.value = false
usersMergeModeSelectOpen.value = false
importUsersResultDialogOpen.value = true
success('用户数据导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入用户数据失败')
log.error('导入用户数据失败:', err)
} finally {
importUsersLoading.value = false
}
}
</script>

View File

@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
})
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
filtered = filtered.filter(
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
filtered = filtered.filter(u => {
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
if (filterRole.value !== 'all') {

View File

@@ -103,7 +103,7 @@
</div>
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
平均响应
@@ -114,7 +114,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-kraft/30">
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
错误率
@@ -128,7 +128,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
转移次数
@@ -142,7 +142,7 @@
v-if="costStats"
class="relative p-3 sm:p-4 border-manilla/40"
>
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
实际成本
@@ -180,7 +180,7 @@
</div>
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存命中率
@@ -191,7 +191,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-kraft/30">
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存读取
@@ -202,7 +202,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存创建
@@ -216,7 +216,7 @@
v-if="tokenBreakdown"
class="relative p-3 sm:p-4 border-manilla/40"
>
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
总Token
@@ -254,16 +254,16 @@
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
<div
v-if="loadingAnnouncements"
class="py-8 text-center"
class="flex-1 flex items-center justify-center"
>
<Loader2 class="h-5 w-5 animate-spin mx-auto text-muted-foreground" />
<Loader2 class="h-5 w-5 animate-spin text-muted-foreground" />
</div>
<div
v-else-if="announcements.length === 0"
class="py-8 text-center"
class="flex-1 flex flex-col items-center justify-center"
>
<Bell class="h-8 w-8 mx-auto text-muted-foreground/40" />
<Bell class="h-8 w-8 text-muted-foreground/40" />
<p class="mt-2 text-xs text-muted-foreground">
暂无公告
</p>
@@ -793,9 +793,8 @@ const statCardGlows = [
'bg-kraft/30'
]
const getStatIconColor = (index: number): string => {
const colors = ['text-book-cloth', 'text-kraft', 'text-book-cloth', 'text-kraft']
return colors[index % colors.length]
const getStatIconColor = (_index: number): string => {
return 'text-muted-foreground'
}
// 统计数据

View File

@@ -474,13 +474,13 @@ async function toggleCapability(modelName: string, capName: string) {
const filteredModels = computed(() => {
let result = models.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选

View File

@@ -62,6 +62,7 @@
<Button
type="submit"
:disabled="savingProfile"
class="shadow-none hover:shadow-none"
>
{{ savingProfile ? '保存中...' : '保存修改' }}
</Button>
@@ -107,6 +108,7 @@
<Button
type="submit"
:disabled="changingPassword"
class="shadow-none hover:shadow-none"
>
{{ changingPassword ? '修改中...' : '修改密码' }}
</Button>
@@ -320,6 +322,7 @@
import { ref, onMounted } from 'vue'
import { useAuthStore } from '@/stores/auth'
import { meApi, type Profile } from '@/api/me'
import { useDarkMode, type ThemeMode } from '@/composables/useDarkMode'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue'
@@ -338,6 +341,7 @@ import { log } from '@/utils/logger'
const authStore = useAuthStore()
const { success, error: showError } = useToast()
const { setThemeMode } = useDarkMode()
const profile = ref<Profile | null>(null)
@@ -375,20 +379,8 @@ function handleThemeChange(value: string) {
themeSelectOpen.value = false
updatePreferences()
// 应用主题
if (value === 'dark') {
document.documentElement.classList.add('dark')
} else if (value === 'light') {
document.documentElement.classList.remove('dark')
} else {
// system: 跟随系统
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches
if (prefersDark) {
document.documentElement.classList.add('dark')
} else {
document.documentElement.classList.remove('dark')
}
}
// 使用 useDarkMode 统一切换主题
setThemeMode(value as ThemeMode)
}
function handleLanguageChange(value: string) {
@@ -418,10 +410,16 @@ async function loadProfile() {
async function loadPreferences() {
try {
const prefs = await meApi.getPreferences()
// 主题以本地 localStorage 为准(useDarkMode 在应用启动时已初始化)
// 这样可以避免刷新页面时主题被服务端旧值覆盖
const { themeMode: currentThemeMode } = useDarkMode()
const localTheme = currentThemeMode.value
preferencesForm.value = {
avatar_url: prefs.avatar_url || '',
bio: prefs.bio || '',
theme: prefs.theme || 'light',
theme: localTheme, // 使用本地主题,而非服务端返回值
language: prefs.language || 'zh-CN',
timezone: prefs.timezone || 'Asia/Shanghai',
notifications: {
@@ -431,11 +429,12 @@ async function loadPreferences() {
}
}
// 应用主题
if (preferencesForm.value.theme === 'dark') {
document.documentElement.classList.add('dark')
} else if (preferencesForm.value.theme === 'light') {
document.documentElement.classList.remove('dark')
// 如果本地主题和服务端不一致,同步到服务端(静默更新,不提示用户)
const serverTheme = prefs.theme || 'light'
if (localTheme !== serverTheme) {
meApi.updatePreferences({ theme: localTheme }).catch(() => {
// 静默失败,不影响用户体验
})
}
} catch (error) {
log.error('加载偏好设置失败:', error)
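The local-wins reconciliation above can be stated compactly: trust localStorage for display, then silently push it to the server when the two diverge. A sketch (assuming an updatePreferences-like function is passed in):

    async function reconcileTheme(
      localTheme: string,
      serverTheme: string,
      push: (theme: string) => Promise<void>,
    ): Promise<void> {
      if (localTheme !== serverTheme) {
        // Silent: a failed sync only leaves the server value stale; display is unaffected
        await push(localTheme).catch(() => {})
      }
    }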

View File

@@ -3,10 +3,8 @@
A proxy server that enables AI models to work with multiple API providers.
"""
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# 注意: dotenv 加载已统一移至 src/config/settings.py
# 不要在此处重复加载
try:
from src._version import __version__

View File

@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)
__all__ = ["router"]

View File

@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
allowed_providers=self.key_data.allowed_providers,
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit or 100,
rate_limit=self.key_data.rate_limit, # None 表示不限制
expire_days=self.key_data.expire_days,
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key

View File

@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
pipeline = ApiRequestPipeline()
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
"""对代理配置中的密码进行脱敏处理"""
if not proxy_config:
return None
masked = dict(proxy_config)
if masked.get("password"):
masked["password"] = "***"
return masked
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
provider_id: str,
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
"proxy": mask_proxy_password(endpoint.proxy),
}
endpoint_dict.pop("_sa_instance_state", None)
result.append(ProviderEndpointResponse(**endpoint_dict))
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
created_at=now,
updated_at=now,
)
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in new_endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=new_endpoint.api_format,
proxy=mask_proxy_password(new_endpoint.proxy),
total_keys=0,
active_keys=0,
)
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint_obj.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=endpoint_obj.api_format,
proxy=mask_proxy_password(endpoint_obj.proxy),
total_keys=total_keys,
active_keys=active_keys,
)
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
update_data = self.endpoint_data.model_dump(exclude_unset=True)
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
if "proxy" in update_data:
if update_data["proxy"] is not None:
new_proxy = dict(update_data["proxy"])
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
if "password" not in new_proxy and endpoint.proxy:
old_password = endpoint.proxy.get("password")
if old_password:
new_proxy["password"] = old_password
update_data["proxy"] = new_proxy
# proxy 为 None 时保留 None 值,用于清除代理配置
for field, value in update_data.items():
setattr(endpoint, field, value)
endpoint.updated_at = datetime.now(timezone.utc)
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name if provider else "Unknown",
api_format=endpoint.api_format,
proxy=mask_proxy_password(endpoint.proxy),
total_keys=total_keys,
active_keys=active_keys,
)
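The update path gives the proxy password three distinct states: omitted means keep the stored secret, empty string means clear it, any other value replaces it. From the client's perspective (a sketch mirroring the Python above):

    type ProxyPayload = { url: string; username?: string; password?: string; enabled: boolean }

    function describePasswordUpdate(p: ProxyPayload): string {
      if (p.password === undefined) return 'keep stored password'   // field omitted in JSON
      if (p.password === '') return 'clear stored password'         // explicit empty string
      return 'store new password'
    }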

View File

@@ -21,7 +21,8 @@ from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
from src.services.cache.affinity_manager import get_affinity_manager
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
from src.services.cache.aware_scheduler import CacheAwareScheduler, get_cache_aware_scheduler
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
pipeline = ApiRequestPipeline()
@@ -185,6 +186,30 @@ async def clear_user_cache(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/affinity/{affinity_key}/{endpoint_id}/{model_id}/{api_format}")
async def clear_single_affinity(
affinity_key: str,
endpoint_id: str,
model_id: str,
api_format: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
Clear a single cache affinity entry
Parameters:
- affinity_key: API Key ID
- endpoint_id: Endpoint ID
- model_id: Model ID (GlobalModel ID)
- api_format: API format (claude/openai)
"""
adapter = AdminClearSingleAffinityAdapter(
affinity_key=affinity_key, endpoint_id=endpoint_id, model_id=model_id, api_format=api_format
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("")
async def clear_all_cache(
request: Request,
@@ -250,7 +275,22 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
# 读取系统配置,确保监控接口与编排器使用一致的模式
priority_mode = SystemConfigService.get_config(
context.db,
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
context.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
scheduler = await get_cache_aware_scheduler(
redis_client,
priority_mode=priority_mode,
scheduling_mode=scheduling_mode,
)
stats = await scheduler.get_stats()
logger.info("缓存统计信息查询成功")
context.add_audit_metadata(
@@ -270,7 +310,22 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
# 读取系统配置,确保监控接口与编排器使用一致的模式
priority_mode = SystemConfigService.get_config(
context.db,
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
context.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
scheduler = await get_cache_aware_scheduler(
redis_client,
priority_mode=priority_mode,
scheduling_mode=scheduling_mode,
)
stats = await scheduler.get_stats()
payload = self._format_prometheus(stats)
context.add_audit_metadata(
@@ -624,6 +679,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
"key_name": key.name if key else None,
"key_prefix": provider_key_masked,
"rate_multiplier": key.rate_multiplier if key else 1.0,
"global_model_id": affinity.get("model_name"), # 原始的 global_model_id
"model_name": (
global_model_map.get(affinity.get("model_name")).name
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
@@ -786,6 +842,65 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearSingleAffinityAdapter(AdminApiAdapter):
affinity_key: str
endpoint_id: str
model_id: str
api_format: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
# 直接获取指定的亲和性记录(无需遍历全部)
existing_affinity = await affinity_mgr.get_affinity(
self.affinity_key, self.api_format, self.model_id
)
if not existing_affinity:
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
# 验证 endpoint_id 是否匹配
if existing_affinity.endpoint_id != self.endpoint_id:
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
# 失效单条记录
await affinity_mgr.invalidate_affinity(
self.affinity_key, self.api_format, self.model_id, endpoint_id=self.endpoint_id
)
# 获取用于日志的信息
api_key = db.query(ApiKey).filter(ApiKey.id == self.affinity_key).first()
api_key_name = api_key.name if api_key else None
logger.info(
f"已清除单条缓存亲和性: affinity_key={self.affinity_key[:8]}..., "
f"endpoint_id={self.endpoint_id[:8]}..., model_id={self.model_id[:8]}..."
)
context.add_audit_metadata(
action="cache_clear_single",
affinity_key=self.affinity_key,
endpoint_id=self.endpoint_id,
model_id=self.model_id,
)
return {
"status": "ok",
"message": f"已清除缓存亲和性: {api_key_name or self.affinity_key[:8]}",
"affinity_key": self.affinity_key,
"endpoint_id": self.endpoint_id,
"model_id": self.model_id,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除单条缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class RequestTraceResponse(BaseModel):

View File

@@ -1,46 +1,28 @@
"""
Provider Query API 端点
用于查询提供商的余额、使用记录等信息
用于查询提供商的模型列表等信息
"""
from datetime import datetime
import asyncio
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# 初始化适配器注册
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.models.database import Provider, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
"""余额查询请求"""
provider_id: str
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
class UsageSummaryQueryRequest(BaseModel):
"""使用汇总查询请求"""
provider_id: str
api_key_id: Optional[str] = None
period: str = "month" # day, week, month, year
class ModelsQueryRequest(BaseModel):
"""模型列表查询请求"""
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
current_user: User = Depends(get_current_user),
):
"""
获取所有可用的查询适配器
async def _fetch_openai_models(
client: httpx.AsyncClient,
base_url: str,
api_key: str,
api_format: str,
extra_headers: Optional[dict] = None,
) -> tuple[list, Optional[str]]:
"""获取 OpenAI 格式的模型列表
Returns:
适配器列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
registry = get_query_registry()
adapters = registry.list_adapters()
headers = {"Authorization": f"Bearer {api_key}"}
if extra_headers:
# 防止 extra_headers 覆盖 Authorization
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
headers.update(safe_headers)
return {"success": True, "data": adapters}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
# 记录详细的错误信息
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg
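Each fetcher normalizes a base_url that may or may not already end in /v1. The same normalization as a small function (a sketch; trailing slashes are stripped here too, which the Python only does via rstrip in the Gemini variant):

    function modelsUrl(baseUrl: string): string {
      const clean = baseUrl.replace(/\/+$/, '')
      return clean.endsWith('/v1') ? `${clean}/models` : `${clean}/v1/models`
    }
    // modelsUrl('https://api.openai.com/v1')  -> 'https://api.openai.com/v1/models'
    // modelsUrl('https://api.example.com')    -> 'https://api.example.com/v1/models'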
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
provider_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取提供商支持的查询能力
Args:
provider_id: 提供商 ID
async def _fetch_claude_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Claude 格式的模型列表
Returns:
支持的查询能力列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
# 获取提供商
from sqlalchemy import select
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
capabilities = registry.get_capabilities_for_provider(provider.name)
if capabilities is None:
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [],
"has_adapter": False,
"message": "No query adapter available for this provider",
},
}
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [c.name for c in capabilities],
"has_adapter": True,
},
headers = {
"x-api-key": api_key,
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
@router.post("/balance")
async def query_balance(
request: BalanceQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商余额
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg
Args:
request: 查询请求
async def _fetch_gemini_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Gemini 格式的模型列表
Returns:
余额信息
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 兼容 base_url 已包含 /v1beta 的情况
base_url_clean = base_url.rstrip("/")
if base_url_clean.endswith("/v1beta"):
models_url = f"{base_url_clean}/models?key={api_key}"
else:
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
# 查找指定的 API Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
try:
response = await client.get(models_url)
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
if "models" in data:
# 转换为统一格式
return [
{
"id": m.get("name", "").replace("models/", ""),
"owned_by": "google",
"display_name": m.get("displayName", ""),
"api_format": api_format,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 使用第一个可用的 API Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询余额
registry = get_query_registry()
query_result = await registry.query_provider_balance(
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
)
if not query_result.success:
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/usage-summary")
async def query_usage_summary(
request: UsageSummaryQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商使用汇总
Args:
request: 查询请求
Returns:
使用汇总信息
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key逻辑同上
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询使用汇总
registry = get_query_registry()
query_result = await registry.query_provider_usage(
provider_type=provider.name,
api_key=api_key_value,
period=request.period,
endpoint_config=endpoint_config,
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
for m in data["models"]
], None
return [], None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg
@router.post("/models")
async def query_available_models(
request: ModelsQueryRequest,
db: AsyncSession = Depends(get_db),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商可用模型
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
Args:
request: 查询请求
Returns:
模型列表
所有端点的模型列表(合并)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
# 收集所有活跃端点的配置
endpoint_configs: list[dict] = []
if request.api_key_id:
# 指定了特定的 API Key只使用该 Key 对应的端点
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break
if api_key_value:
if endpoint_configs:
break
if not api_key_value:
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 遍历所有活跃端点,每个端点取第一个可用的 Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not endpoint.is_active or not endpoint.api_keys:
continue
if not api_key_value:
# 找第一个可用的 Key
for api_key in endpoint.api_keys:
if api_key.is_active:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # 尝试下一个 Key
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询模型
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
# 并发请求所有端点的模型列表
all_models: list = []
errors: list[str] = []
if not adapter:
raise HTTPException(
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
async def fetch_endpoint_models(
client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
base_url = config["base_url"]
if not base_url:
return [], None
base_url = base_url.rstrip("/")
api_format = config["api_format"]
api_key_value = config["api_key"]
extra_headers = config["extra_headers"]
try:
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
elif api_format in ["GEMINI", "GEMINI_CLI"]:
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
else:
return await _fetch_openai_models(
client, base_url, api_key_value, api_format, extra_headers
)
except Exception as e:
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
return [], f"{api_format}: {str(e)}"
async with httpx.AsyncClient(timeout=30.0) as client:
results = await asyncio.gather(
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
)
for models, error in results:
all_models.extend(models)
if error:
errors.append(error)
query_result = await adapter.query_available_models(
api_key=api_key_value, endpoint_config=endpoint_config
)
# 按 model id 去重(保留第一个)
seen_ids: set[str] = set()
unique_models: list = []
for model in all_models:
model_id = model.get("id")
if model_id and model_id not in seen_ids:
seen_ids.add(model_id)
unique_models.append(model)
error = "; ".join(errors) if errors else None
if not unique_models and not error:
error = "No models returned from any endpoint"
return {
"success": query_result.success,
"data": query_result.to_dict(),
"success": len(unique_models) > 0,
"data": {"models": unique_models, "error": error},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
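
For reference, a minimal sketch of exercising this endpoint from an async client (the /api/admin/provider-query prefix comes from the router above; the host, token, and provider id are placeholders):

import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/api/admin/provider-query/models",
            json={"provider_id": "prov-123"},  # hypothetical provider id
            headers={"Authorization": "Bearer <admin-token>"},
        )
        payload = resp.json()
        # Merged, deduplicated model list across all active endpoints
        for model in payload["data"]["models"]:
            print(model["id"], model["api_format"])

asyncio.run(main())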
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
provider_id: str,
api_key_id: Optional[str] = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
清除查询缓存
Args:
provider_id: 提供商 ID
api_key_id: 可选,指定清除某个 API Key 的缓存
Returns:
清除结果
"""
from sqlalchemy import select
# 获取提供商
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if adapter:
if api_key_id:
# 获取 API Key 值来清除缓存
from sqlalchemy.orm import selectinload
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
api_key = result.scalar_one_or_none()
if api_key:
adapter.clear_cache(api_key.api_key)
else:
adapter.clear_cache()
return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.models_service import invalidate_models_list_cache
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
@@ -419,4 +420,8 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
)
# Invalidate the /v1/models list cache
if success:
await invalidate_models_list_cache()
return BatchAssignModelsToProviderResponse(success=success, errors=errors)

View File

@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
"""导出提供商和模型配置(管理员)"""
adapter = AdminExportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
"""导入提供商和模型配置(管理员)"""
adapter = AdminImportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
"""导出用户数据(管理员)"""
adapter = AdminExportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
"""导入用户数据(管理员)"""
adapter = AdminImportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- System settings adapters --------
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
)
return {"formats": formats}
class AdminExportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出提供商和模型配置(解密数据)"""
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
db = context.db
# Export GlobalModels
global_models = db.query(GlobalModel).all()
global_models_data = []
for gm in global_models:
global_models_data.append(
{
"name": gm.name,
"display_name": gm.display_name,
"default_price_per_request": gm.default_price_per_request,
"default_tiered_pricing": gm.default_tiered_pricing,
"supported_capabilities": gm.supported_capabilities,
"config": gm.config,
"is_active": gm.is_active,
}
)
# Export Providers and their related data
providers = db.query(Provider).all()
providers_data = []
for provider in providers:
# Export Endpoints
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == provider.id)
.all()
)
endpoints_data = []
for ep in endpoints:
# Export Endpoint Keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# Decrypt the API key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
}
)
# Export Provider Models
models = db.query(Model).filter(Model.provider_id == provider.id).all()
models_data = []
for model in models:
# Look up the associated GlobalModel name
global_model = (
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
)
models_data.append(
{
"global_model_name": global_model.name if global_model else None,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": model.provider_model_aliases,
"price_per_request": model.price_per_request,
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": model.supports_image_generation,
"is_active": model.is_active,
"config": model.config,
}
)
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"rate_limit": provider.rate_limit,
"concurrent_limit": provider.concurrent_limit,
"config": provider.config,
"endpoints": endpoints_data,
"models": models_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
}
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
class AdminImportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入提供商和模型配置"""
import uuid
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.core.enums import ProviderBillingType
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
# Check the request body size
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
    raise InvalidRequestException("Request body must not exceed 10MB")
db = context.db
payload = context.ensure_json_body()
# Validate the config version
version = payload.get("version")
if version != "1.0":
    raise InvalidRequestException(f"Unsupported config version: {version}")
# Read import options
merge_mode = payload.get("merge_mode", "skip")  # skip, overwrite, error
global_models_data = payload.get("global_models", [])
providers_data = payload.get("providers", [])
stats = {
"global_models": {"created": 0, "updated": 0, "skipped": 0},
"providers": {"created": 0, "updated": 0, "skipped": 0},
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
"keys": {"created": 0, "updated": 0, "skipped": 0},
"models": {"created": 0, "updated": 0, "skipped": 0},
"errors": [],
}
try:
# Import GlobalModels
global_model_map = {}  # name -> id mapping
for gm_data in global_models_data:
existing = (
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
)
if existing:
global_model_map[gm_data["name"]] = existing.id
if merge_mode == "skip":
stats["global_models"]["skipped"] += 1
continue
elif merge_mode == "error":
raise InvalidRequestException(
    f"GlobalModel '{gm_data['name']}' already exists"
)
elif merge_mode == "overwrite":
# Update the existing record
existing.display_name = gm_data.get(
"display_name", existing.display_name
)
existing.default_price_per_request = gm_data.get(
"default_price_per_request"
)
existing.default_tiered_pricing = gm_data.get(
"default_tiered_pricing", existing.default_tiered_pricing
)
existing.supported_capabilities = gm_data.get(
"supported_capabilities"
)
existing.config = gm_data.get("config")
existing.is_active = gm_data.get("is_active", True)
existing.updated_at = datetime.now(timezone.utc)
stats["global_models"]["updated"] += 1
else:
# Create a new record
new_gm = GlobalModel(
id=str(uuid.uuid4()),
name=gm_data["name"],
display_name=gm_data.get("display_name", gm_data["name"]),
default_price_per_request=gm_data.get("default_price_per_request"),
default_tiered_pricing=gm_data.get(
"default_tiered_pricing",
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
),
supported_capabilities=gm_data.get("supported_capabilities"),
config=gm_data.get("config"),
is_active=gm_data.get("is_active", True),
)
db.add(new_gm)
db.flush()
global_model_map[gm_data["name"]] = new_gm.id
stats["global_models"]["created"] += 1
# Import Providers
for prov_data in providers_data:
existing_provider = (
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
)
if existing_provider:
provider_id = existing_provider.id
if merge_mode == "skip":
stats["providers"]["skipped"] += 1
# Endpoints and models below still need processing (if present)
elif merge_mode == "error":
raise InvalidRequestException(
    f"Provider '{prov_data['name']}' already exists"
)
elif merge_mode == "overwrite":
# Update the existing record
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
if prov_data.get("billing_type"):
existing_provider.billing_type = ProviderBillingType(
prov_data["billing_type"]
)
existing_provider.monthly_quota_usd = prov_data.get(
"monthly_quota_usd"
)
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
existing_provider.is_active = prov_data.get("is_active", True)
existing_provider.rate_limit = prov_data.get("rate_limit")
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
else:
# Create a new Provider
billing_type = ProviderBillingType.PAY_AS_YOU_GO
if prov_data.get("billing_type"):
billing_type = ProviderBillingType(prov_data["billing_type"])
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
rate_limit=prov_data.get("rate_limit"),
concurrent_limit=prov_data.get("concurrent_limit"),
config=prov_data.get("config"),
)
db.add(new_provider)
db.flush()
provider_id = new_provider.id
stats["providers"]["created"] += 1
# Import Endpoints
for ep_data in prov_data.get("endpoints", []):
existing_ep = (
db.query(ProviderEndpoint)
.filter(
ProviderEndpoint.provider_id == provider_id,
ProviderEndpoint.api_format == ep_data["api_format"],
)
.first()
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
    f"Endpoint '{ep_data['api_format']}' already exists on Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_ep.base_url = ep_data.get(
"base_url", existing_ep.base_url
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
new_ep = ProviderEndpoint(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_format=ep_data["api_format"],
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# Import Keys
# Fetch all existing keys under this endpoint, for deduplication
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
.all()
)
# Decrypt existing keys for comparison
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
)
continue
# Check whether an identical key already exists (plaintext comparison)
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
)
db.add(new_key)
# Add to the existing set to prevent duplicates within the same import batch
existing_key_values.add(key_data["api_key"])
stats["keys"]["created"] += 1
# Import Models
for model_data in prov_data.get("models", []):
global_model_name = model_data.get("global_model_name")
if not global_model_name:
stats["errors"].append(
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
)
continue
global_model_id = global_model_map.get(global_model_name)
if not global_model_id:
# Fall back to a database lookup
existing_gm = (
db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name)
.first()
)
if existing_gm:
global_model_id = existing_gm.id
else:
stats["errors"].append(
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
)
continue
existing_model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.provider_model_name == model_data["provider_model_name"],
)
.first()
)
if existing_model:
if merge_mode == "skip":
stats["models"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
    f"Model '{model_data['provider_model_name']}' already exists on Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_model.global_model_id = global_model_id
existing_model.provider_model_aliases = model_data.get(
"provider_model_aliases"
)
existing_model.price_per_request = model_data.get(
"price_per_request"
)
existing_model.tiered_pricing = model_data.get(
"tiered_pricing"
)
existing_model.supports_vision = model_data.get(
"supports_vision"
)
existing_model.supports_function_calling = model_data.get(
"supports_function_calling"
)
existing_model.supports_streaming = model_data.get(
"supports_streaming"
)
existing_model.supports_extended_thinking = model_data.get(
"supports_extended_thinking"
)
existing_model.supports_image_generation = model_data.get(
"supports_image_generation"
)
existing_model.is_active = model_data.get("is_active", True)
existing_model.config = model_data.get("config")
existing_model.updated_at = datetime.now(timezone.utc)
stats["models"]["updated"] += 1
else:
new_model = Model(
id=str(uuid.uuid4()),
provider_id=provider_id,
global_model_id=global_model_id,
provider_model_name=model_data["provider_model_name"],
provider_model_aliases=model_data.get(
"provider_model_aliases"
),
price_per_request=model_data.get("price_per_request"),
tiered_pricing=model_data.get("tiered_pricing"),
supports_vision=model_data.get("supports_vision"),
supports_function_calling=model_data.get(
"supports_function_calling"
),
supports_streaming=model_data.get("supports_streaming"),
supports_extended_thinking=model_data.get(
"supports_extended_thinking"
),
supports_image_generation=model_data.get(
"supports_image_generation"
),
is_active=model_data.get("is_active", True),
config=model_data.get("config"),
)
db.add(new_model)
stats["models"]["created"] += 1
db.commit()
# Invalidate caches
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.clear_all_caches()
return {
"message": "配置导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")
class AdminExportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出用户数据(保留加密数据,排除管理员)"""
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
db = context.db
# Export Users (excluding admins)
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# Export the user's API keys (keeping encrypted values)
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
users_data.append(
{
"email": user.email,
"username": user.username,
"password_hash": user.password_hash,
"role": user.role.value if user.role else "user",
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"model_capability_settings": user.model_capability_settings,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
"is_active": user.is_active,
"api_keys": api_keys_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
}
class AdminImportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入用户数据"""
import uuid
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
# Check the request body size
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
    raise InvalidRequestException("Request body must not exceed 10MB")
db = context.db
payload = context.ensure_json_body()
# Validate the config version
version = payload.get("version")
if version != "1.0":
    raise InvalidRequestException(f"Unsupported config version: {version}")
# Read import options
merge_mode = payload.get("merge_mode", "skip")  # skip, overwrite, error
users_data = payload.get("users", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"errors": [],
}
try:
for user_data in users_data:
# Skip importing admin-role users (case-insensitive)
role_str = str(user_data.get("role", "")).lower()
if role_str == "admin":
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
stats["users"]["skipped"] += 1
continue
existing_user = (
db.query(User).filter(User.email == user_data["email"]).first()
)
if existing_user:
user_id = existing_user.id
if merge_mode == "skip":
stats["users"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
    f"User '{user_data['email']}' already exists"
)
elif merge_mode == "overwrite":
# Update the existing user
existing_user.username = user_data.get(
"username", existing_user.username
)
if user_data.get("password_hash"):
existing_user.password_hash = user_data["password_hash"]
if user_data.get("role"):
existing_user.role = UserRole(user_data["role"])
existing_user.allowed_providers = user_data.get("allowed_providers")
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
existing_user.allowed_models = user_data.get("allowed_models")
existing_user.model_capability_settings = user_data.get(
"model_capability_settings"
)
existing_user.quota_usd = user_data.get("quota_usd")
existing_user.used_usd = user_data.get("used_usd", 0.0)
existing_user.total_usd = user_data.get("total_usd", 0.0)
existing_user.is_active = user_data.get("is_active", True)
existing_user.updated_at = datetime.now(timezone.utc)
stats["users"]["updated"] += 1
else:
# Create a new user
role = UserRole.USER
if user_data.get("role"):
role = UserRole(user_data["role"])
new_user = User(
id=str(uuid.uuid4()),
email=user_data["email"],
username=user_data.get("username", user_data["email"].split("@")[0]),
password_hash=user_data.get("password_hash", ""),
role=role,
allowed_providers=user_data.get("allowed_providers"),
allowed_endpoints=user_data.get("allowed_endpoints"),
allowed_models=user_data.get("allowed_models"),
model_capability_settings=user_data.get("model_capability_settings"),
quota_usd=user_data.get("quota_usd"),
used_usd=user_data.get("used_usd", 0.0),
total_usd=user_data.get("total_usd", 0.0),
is_active=user_data.get("is_active", True),
)
db.add(new_user)
db.flush()
user_id = new_user.id
stats["users"]["created"] += 1
# Import API Keys
for key_data in user_data.get("api_keys", []):
# Check whether the same key_hash already exists
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
db.add(new_key)
stats["api_keys"]["created"] += 1
db.commit()
return {
"message": "用户数据导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")

View File

@@ -140,9 +140,9 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
if not authorization or not authorization.lower().startswith("bearer "):
return None
token = authorization[7:].strip()
try:
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:
return None
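
A small illustration (contrived header value) of why the slice is the safer fix: the startswith check above is case-insensitive, while str.replace would miss a lowercase "bearer " prefix and would strip the literal "Bearer " anywhere inside the token:

auth = "bearer abcBearer def"
assert auth.lower().startswith("bearer ")
print(auth.replace("Bearer ", "").strip())  # "bearer abcdef" -- token corrupted
print(auth[7:].strip())                     # "abcBearer def" -- only the prefix removed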

View File

@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
from src.models.database import SystemConfig
db = context.db
payload = context.ensure_json_body()

View File

@@ -55,6 +55,23 @@ async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"])
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
async def invalidate_models_list_cache() -> None:
"""
清除所有 /v1/models 列表缓存
在模型创建、更新、删除时调用,确保模型列表实时更新
"""
# 清除所有格式的缓存
all_formats = ["CLAUDE", "OPENAI", "GEMINI"]
for fmt in all_formats:
cache_key = f"{_CACHE_KEY_PREFIX}:{fmt}"
try:
await CacheService.delete(cache_key)
logger.debug(f"[ModelsService] 已清除缓存: {cache_key}")
except Exception as e:
logger.warning(f"[ModelsService] 清除缓存失败 {cache_key}: {e}")
@dataclass
class ModelInfo:
"""统一的模型信息结构"""

View File

@@ -5,13 +5,12 @@ from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
@@ -178,9 +177,9 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少管理员凭证")
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# Query the database directly to ensure the returned object is bound to the current Session
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -205,9 +204,9 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少用户凭证")
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# Query the database directly to ensure the returned object is bound to the current Session
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
"""记录审计事件
事务策略:复用请求级 Session不单独提交。
审计记录随主事务一起提交,由中间件统一管理。
"""
if not getattr(adapter, "audit_log_enabled", True):
return
if context.db is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
error=error,
)
try:
# Reuse the request-level Session; do not create a new connection
# The audit record is committed with the main transaction, managed by the middleware
self.audit_service.log_event(
db=context.db,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
error_message=error,
metadata=metadata,
)
except Exception as exc:
# Audit failures must not affect the main request; only log a warning
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
def _build_audit_metadata(
self,

View File

@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
)
# stats_daily.date stores the UTC start time of the corresponding business date
# Convert back to the business timezone before taking the date, so it matches the date series
def _to_business_date_str(value: datetime) -> str:
if value.tzinfo is None:
value_utc = value.replace(tzinfo=timezone.utc)
else:
value_utc = value.astimezone(timezone.utc)
return value_utc.astimezone(app_tz).date().isoformat()
stats_map = {
_to_business_date_str(stat.date): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
"unique_providers": today_unique_providers,
"fallback_count": today_fallback_count,
}
# Fallback when historical pre-aggregates are missing: compute per business day in real time (only backfill the few most recent gaps, to avoid a full-table scan)
yesterday_date = today_local.date() - timedelta(days=1)
historical_end = min(end_date_local.date(), yesterday_date)
missing_dates: list[str] = []
cursor = start_date_local.date()
while cursor <= historical_end:
date_str = cursor.isoformat()
if date_str not in stats_map:
missing_dates.append(date_str)
cursor += timedelta(days=1)
if missing_dates:
for date_str in missing_dates[-7:]:
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
stats_map[date_str] = {
"requests": computed["total_requests"],
"tokens": (
computed["input_tokens"]
+ computed["output_tokens"]
+ computed["cache_creation_tokens"]
+ computed["cache_read_tokens"]
),
"cost": computed["total_cost"],
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
if computed["avg_response_time_ms"]
else 0,
"unique_models": computed["unique_models"],
"unique_providers": computed["unique_providers"],
"fallback_count": computed["fallback_count"],
}
else:
# Regular users: still query in real time (user-level pre-aggregation is optional)
query = db.query(Usage).filter(

View File

@@ -28,7 +28,7 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
if TYPE_CHECKING:
from src.api.handlers.base.stream_context import StreamContext
class MessageTelemetry:
@@ -399,6 +402,41 @@ class BaseMessageHandler:
# Create a background task; do not block the current stream
asyncio.create_task(_do_update())
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
"""更新 Usage 状态为 streaming同时更新 provider 和 target_model
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
Args:
ctx: 流式上下文,包含 provider_name 和 mapped_model
"""
import asyncio
from src.database.database import get_db
target_request_id = self.request_id
provider = ctx.provider_name
target_model = ctx.mapped_model
async def _do_update() -> None:
try:
db_gen = get_db()
db = next(db_gen)
try:
UsageService.update_usage_status(
db=db,
request_id=target_request_id,
status="streaming",
provider=provider,
target_model=target_model,
)
finally:
db.close()
except Exception as e:
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
# Create a background task; do not block the current stream
asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
@@ -411,9 +449,10 @@ class BaseMessageHandler:
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
UpstreamClientException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
# Business exception: concise log without a stack trace
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:

View File

@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
self.response_normalizer = None
# Optionally enable response normalization
self._init_response_normalizer()
def _init_response_normalizer(self):
"""初始化响应规范化器 - 子类可覆盖"""
try:
from src.services.provider.response_normalizer import ResponseNormalizer
self.response_normalizer = ResponseNormalizer()
except ImportError:
pass
async def handle(self, context: ApiRequestContext):
"""处理 Chat API 请求"""
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
response_normalizer=self.response_normalizer,
enable_response_normalization=self.response_normalizer is not None,
adapter_detector=self.detect_capability_requirements,
)

View File

@@ -88,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
response_normalizer: Optional[Any] = None,
enable_response_normalization: bool = False,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
@@ -106,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
self._parser: Optional[ResponseParser] = None
self._request_builder = PassthroughRequestBuilder()
self.response_normalizer = response_normalizer
self.enable_response_normalization = enable_response_normalization
@property
def parser(self) -> ResponseParser:
@@ -266,8 +262,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if mapping and mapping.model:
# Use select_provider_model_name to support aliases
# Pass api_key.id as the affinity_key so the same user consistently selects the same alias
# Pass api_format to filter the alias scopes that apply
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name
@@ -294,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# Create a type-safe streaming context
ctx = StreamContext(model=model, api_format=api_format)
# Create a status-update callback closure (it can access ctx)
def update_streaming_status() -> None:
self._update_usage_to_streaming_with_ctx(ctx)
# Create the stream processor
stream_processor = StreamProcessor(
request_id=self.request_id,
default_parser=self.parser,
on_streaming_start=update_streaming_status,
)
# Define the request function
@@ -463,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
pool=config.http_pool_timeout,
)
# Create the HTTP client (with proxy configuration support)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -630,11 +639,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
f"模型={model} -> {mapped_model or '无映射'}")
logger.debug(f" [{self.request_id}] 请求URL: {url}")
logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}")
# Create the HTTP client (with proxy configuration support)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code
@@ -649,10 +664,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
response_headers=response_headers,
)
elif resp.status_code >= 500:
raise ProviderNotAvailableException(f"提供商服务不可用: {provider.name}")
elif resp.status_code != 200:
# 记录响应体以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}")
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}"
f"提供商服务不可用: {provider.name}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
elif resp.status_code != 200:
# 记录非200响应以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}")
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
response_json = resp.json()

View File

@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
if mapping and mapping.model:
# Use select_provider_model_name to support aliases
# Pass api_key.id as the affinity_key so the same user consistently selects the same alias
# Pass api_format to filter the alias scopes that apply
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
@@ -451,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
# Create the HTTP client (with proxy configuration support)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -523,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async for chunk in stream_response.aiter_raw():
# Update the status to streaming before the first data is emitted
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
@@ -807,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# Update the status to streaming before the first data is emitted
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
# Process the prefetched byte chunks first
for chunk in prefetched_chunks:
@@ -1105,8 +1114,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
async for chunk in stream_generator:
yield chunk
except asyncio.CancelledError:
# If the response has already completed, do not mark it as failed
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
raise
except httpx.TimeoutException as e:
ctx.status_code = 504
@@ -1416,10 +1427,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
# Create the HTTP client (with proxy configuration support)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code

View File

@@ -0,0 +1,274 @@
"""
流式内容提取器 - 策略模式实现
为不同 API 格式OpenAI、Claude、Gemini提供内容提取和 chunk 构造的抽象。
StreamSmoother 使用这些提取器来处理不同格式的 SSE 事件。
"""
import copy
import json
from abc import ABC, abstractmethod
from typing import Optional
class ContentExtractor(ABC):
"""
流式内容提取器抽象基类
定义从 SSE 事件中提取文本内容和构造新 chunk 的接口。
每种 API 格式OpenAI、Claude、Gemini需要实现自己的提取器。
"""
@abstractmethod
def extract_content(self, data: dict) -> Optional[str]:
"""
从 SSE 数据中提取可拆分的文本内容
Args:
data: 解析后的 JSON 数据
Returns:
提取的文本内容,如果无法提取则返回 None
"""
pass
@abstractmethod
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
"""
使用新内容构造 SSE chunk
Args:
original_data: 原始 JSON 数据
new_content: 新的文本内容
event_type: SSE 事件类型(某些格式需要)
is_first: 是否是第一个 chunk用于保留 role 等字段)
Returns:
编码后的 SSE 字节数据
"""
pass
class OpenAIContentExtractor(ContentExtractor):
"""
OpenAI 格式内容提取器
处理 OpenAI Chat Completions API 的流式响应格式:
- 数据结构: choices[0].delta.content
- 只在 delta 仅包含 role/content 时允许拆分,避免破坏 tool_calls 等结构
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
choices = data.get("choices")
if not isinstance(choices, list) or len(choices) != 1:
return None
first_choice = choices[0]
if not isinstance(first_choice, dict):
return None
delta = first_choice.get("delta")
if not isinstance(delta, dict):
return None
content = delta.get("content")
if not isinstance(content, str):
return None
# Only allow splitting when the delta contains only role/content,
# to avoid breaking complex structures such as tool_calls and function_call
allowed_keys = {"role", "content"}
if not all(key in allowed_keys for key in delta.keys()):
return None
return content
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = original_data.copy()
if "choices" in new_data and new_data["choices"]:
new_choices = []
for choice in new_data["choices"]:
new_choice = choice.copy()
if "delta" in new_choice:
new_delta = {}
# Only the first chunk keeps the role
if is_first and "role" in new_choice["delta"]:
new_delta["role"] = new_choice["delta"]["role"]
new_delta["content"] = new_content
new_choice["delta"] = new_delta
new_choices.append(new_choice)
new_data["choices"] = new_choices
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
class ClaudeContentExtractor(ContentExtractor):
"""
Claude 格式内容提取器
处理 Claude Messages API 的流式响应格式:
- 事件类型: content_block_delta
- 数据结构: delta.type=text_delta, delta.text
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
# Check the event type
if data.get("type") != "content_block_delta":
return None
delta = data.get("delta", {})
if not isinstance(delta, dict):
return None
# Check the delta type
if delta.get("type") != "text_delta":
return None
text = delta.get("text")
if not isinstance(text, str):
return None
return text
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = original_data.copy()
if "delta" in new_data:
new_delta = new_data["delta"].copy()
new_delta["text"] = new_content
new_data["delta"] = new_delta
# The Claude format requires an "event:" prefix
event_name = event_type or "content_block_delta"
return f"event: {event_name}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
"utf-8"
)
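
The same round-trip for the Claude extractor (event shape per the Messages API streaming format; values illustrative):

extractor = ClaudeContentExtractor()
event = {"type": "content_block_delta", "index": 0,
         "delta": {"type": "text_delta", "text": "Hello world"}}
assert extractor.extract_content(event) == "Hello world"
print(extractor.create_chunk(event, "Hello").decode("utf-8"))
# event: content_block_delta
# data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}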
class GeminiContentExtractor(ContentExtractor):
"""
Gemini 格式内容提取器
处理 Gemini API 的流式响应格式:
- 数据结构: candidates[0].content.parts[0].text
- 只有纯文本块才拆分
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
candidates = data.get("candidates")
if not isinstance(candidates, list) or len(candidates) != 1:
return None
first_candidate = candidates[0]
if not isinstance(first_candidate, dict):
return None
content = first_candidate.get("content", {})
if not isinstance(content, dict):
return None
parts = content.get("parts", [])
if not isinstance(parts, list) or len(parts) != 1:
return None
first_part = parts[0]
if not isinstance(first_part, dict):
return None
text = first_part.get("text")
# Only split pure text parts (those with a text field only)
if not isinstance(text, str) or len(first_part) != 1:
return None
return text
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = copy.deepcopy(original_data)
if "candidates" in new_data and new_data["candidates"]:
first_candidate = new_data["candidates"][0]
if "content" in first_candidate:
content = first_candidate["content"]
if "parts" in content and content["parts"]:
content["parts"][0]["text"] = new_content
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
# Extractor registry
_EXTRACTORS: dict[str, type[ContentExtractor]] = {
"openai": OpenAIContentExtractor,
"claude": ClaudeContentExtractor,
"gemini": GeminiContentExtractor,
}
def get_extractor(format_name: str) -> Optional[ContentExtractor]:
"""
根据格式名获取对应的内容提取器实例
Args:
format_name: 格式名称openai, claude, gemini
Returns:
对应的提取器实例,如果格式不支持则返回 None
"""
extractor_class = _EXTRACTORS.get(format_name.lower())
if extractor_class:
return extractor_class()
return None
def register_extractor(format_name: str, extractor_class: type[ContentExtractor]) -> None:
"""
注册新的内容提取器
Args:
format_name: 格式名称
extractor_class: 提取器类
"""
_EXTRACTORS[format_name.lower()] = extractor_class
def get_extractor_formats() -> list[str]:
"""
获取所有已注册的格式名称列表
Returns:
格式名称列表
"""
return list(_EXTRACTORS.keys())
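
A sketch of extending the registry with a hypothetical OpenAI-compatible format:

class MistralContentExtractor(OpenAIContentExtractor):
    """Hypothetical extractor for a format that reuses the OpenAI delta shape."""

register_extractor("mistral", MistralContentExtractor)
assert "mistral" in get_extractor_formats()
extractor = get_extractor("mistral")  # returns a fresh instance per call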

View File

@@ -6,16 +6,22 @@
2. Response stream generation
3. Prefetch and embedded-error detection
4. Client-disconnect detection
5. Smoothed streaming output
"""
import asyncio
import codecs
import json
import time
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.content_extractors import (
ContentExtractor,
get_extractor,
get_extractor_formats,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
@@ -25,11 +31,20 @@ from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamSmoothingConfig:
"""Stream-smoothing output configuration"""
enabled: bool = False
chunk_size: int = 20
delay_ms: int = 8
class StreamProcessor:
"""
Streaming response processor
Handles SSE stream parsing, error detection, and response generation.
Handles SSE stream parsing, error detection, response generation, and smoothed output.
Extracted from ChatHandlerBase to keep its responsibilities focused.
"""
@@ -40,6 +55,7 @@ class StreamProcessor:
on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
smoothing_config: Optional[StreamSmoothingConfig] = None,
):
"""
Initialize the stream processor
@@ -48,11 +64,17 @@ class StreamProcessor:
request_id: Request ID (used for logging)
default_parser: Default response parser
on_streaming_start: Callback invoked when streaming starts (used to update state)
collect_text: Whether to collect text content
smoothing_config: Stream-smoothing output configuration
"""
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
self.smoothing_config = smoothing_config or StreamSmoothingConfig()
# Content-extractor cache
self._extractors: dict[str, ContentExtractor] = {}
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
@@ -127,6 +149,13 @@ class StreamProcessor:
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
# Check the OpenAI-format finish_reason
choices = data.get("choices", [])
if choices and isinstance(choices, list) and len(choices) > 0:
finish_reason = choices[0].get("finish_reason")
if finish_reason is not None:
ctx.has_completion = True
async def prefetch_and_check_error(
self,
byte_iterator: Any,
@@ -369,7 +398,7 @@ class StreamProcessor:
sse_parser: SSE parser
line: Raw line data
"""
# SSEEventParser takes newline-stripped single lines as input; strip CR/LF here
# SSEEventParser takes "newline-stripped" single lines as input; strip CR/LF here uniformly
# to avoid misreading an empty line as "\n" and breaking event-boundary parsing.
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line)
@@ -400,32 +429,201 @@ class StreamProcessor:
Response data chunks
"""
try:
# Disconnect-check frequency: every await adds scheduling overhead, and checking too often makes the stream "start and stop"
# Throttle by time interval here, balancing prompt upstream-read shutdown against smooth throughput.
next_disconnect_check_at = 0.0
disconnect_check_interval_s = 0.25
# Check for disconnects in a background task so streaming is never blocked
disconnected = False
async for chunk in stream_generator:
now = time.monotonic()
if now >= next_disconnect_check_at:
next_disconnect_check_at = now + disconnect_check_interval_s
async def check_disconnect_background() -> None:
nonlocal disconnected
while not disconnected and not ctx.has_completion:
await asyncio.sleep(0.5)
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
disconnected = True
break
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
# Start the background check task
check_task = asyncio.create_task(check_disconnect_background())
try:
async for chunk in stream_generator:
if disconnected:
# If the response already completed, a client disconnect does not count as a failure
if ctx.has_completion:
logger.info(
f"ID:{self.request_id} | Client disconnected after completion"
)
else:
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499
ctx.error_message = "client_disconnected"
break
yield chunk
finally:
check_task.cancel()
try:
await check_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
# If the response already completed, don't mark it as failed
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def create_smoothed_stream(
self,
stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
"""
Create a smoothed-output stream generator
If smoothing is enabled, split large chunks into small pieces with tiny delays between them.
Otherwise, pass the original stream through unchanged.
Args:
stream_generator: The original stream generator
Yields:
Smoothed response data chunks
"""
if not self.smoothing_config.enabled:
# Smoothing disabled; pass through directly
async for chunk in stream_generator:
yield chunk
return
# Smoothing enabled
buffer = b""
is_first_content = True
async for chunk in stream_generator:
buffer += chunk
# Split SSE events on double newlines (standard SSE framing)
while b"\n\n" in buffer:
event_block, buffer = buffer.split(b"\n\n", 1)
event_str = event_block.decode("utf-8", errors="replace")
# Parse the event block
lines = event_str.strip().split("\n")
data_str = None
event_type = ""
for line in lines:
line = line.rstrip("\r")
if line.startswith("event: "):
event_type = line[7:].strip()
elif line.startswith("data: "):
data_str = line[6:]
# No data line; pass through unchanged
if data_str is None:
yield event_block + b"\n\n"
continue
# Pass [DONE] through unchanged
if data_str.strip() == "[DONE]":
yield event_block + b"\n\n"
continue
# Try to parse the JSON
try:
data = json.loads(data_str)
except json.JSONDecodeError:
yield event_block + b"\n\n"
continue
# Detect the format and extract the content
content, extractor = self._detect_format_and_extract(data)
# Only content longer than one character needs smoothing
if content and len(content) > 1 and extractor:
# Get the configured delay
delay_seconds = self._calculate_delay()
# Split the content
content_chunks = self._split_content(content)
for i, sub_content in enumerate(content_chunks):
is_first = is_first_content and i == 0
# Create a new chunk via the extractor
sse_chunk = extractor.create_chunk(
data,
sub_content,
event_type=event_type,
is_first=is_first,
)
yield sse_chunk
# Add a delay between chunks, except after the last one
if i < len(content_chunks) - 1:
await asyncio.sleep(delay_seconds)
is_first_content = False
else:
# No splitting needed; pass through unchanged
yield event_block + b"\n\n"
if content:
is_first_content = False
# Flush any remaining buffered data
if buffer:
yield buffer
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
"""Get or create the extractor for a format (cached)"""
if format_name not in self._extractors:
extractor = get_extractor(format_name)
if extractor:
self._extractors[format_name] = extractor
return self._extractors.get(format_name)
def _detect_format_and_extract(
self, data: dict
) -> tuple[Optional[str], Optional[ContentExtractor]]:
"""
Detect the data format and extract content
Tries each format's extractor in turn and returns the first successful extraction.
Returns:
(content, extractor): The extracted content and the matching extractor
"""
for format_name in get_extractor_formats():
extractor = self._get_extractor(format_name)
if extractor:
content = extractor.extract_content(data)
if content is not None:
return content, extractor
return None, None
def _calculate_delay(self) -> float:
"""Return the configured delay in seconds"""
return self.smoothing_config.delay_ms / 1000.0
def _split_content(self, content: str) -> list[str]:
"""
Split text into fixed-size chunks
"""
chunk_size = self.smoothing_config.chunk_size
text_length = len(content)
if text_length <= chunk_size:
return [content]
# Split into chunks
chunks = []
for i in range(0, text_length, chunk_size):
chunks.append(content[i : i + chunk_size])
return chunks
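The chunking arithmetic is easiest to see in isolation. A standalone sketch of the same splitting rule, assuming the chunk_size default of 20 from StreamSmoothingConfig:

```python
def split_content(content: str, chunk_size: int = 20) -> list[str]:
    # Content at or under the chunk size stays whole; otherwise slice greedily.
    if len(content) <= chunk_size:
        return [content]
    return [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]


assert split_content("a" * 45) == ["a" * 20, "a" * 20, "a" * 5]
assert split_content("short") == ["short"]
```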
async def _cleanup(
self,
response_ctx: Any,
@@ -440,3 +638,128 @@ class StreamProcessor:
await http_client.aclose()
except Exception:
pass
async def create_smoothed_stream(
stream_generator: AsyncGenerator[bytes, None],
chunk_size: int = 20,
delay_ms: int = 8,
) -> AsyncGenerator[bytes, None]:
"""
Standalone smoothing-stream helper
For CLI handlers and similar callers that don't need a full StreamProcessor instance.
Args:
stream_generator: The original stream generator
chunk_size: Characters per chunk
delay_ms: Delay between chunks, in milliseconds
Yields:
Smoothed response data chunks
"""
processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
async for chunk in processor.smooth(stream_generator):
yield chunk
class _LightweightSmoother:
"""
Lightweight smoother
Contains only the minimal smoothing logic, with no dependency on the rest of StreamProcessor.
"""
def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
self.chunk_size = chunk_size
self.delay_ms = delay_ms
self._extractors: dict[str, ContentExtractor] = {}
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
if format_name not in self._extractors:
extractor = get_extractor(format_name)
if extractor:
self._extractors[format_name] = extractor
return self._extractors.get(format_name)
def _detect_format_and_extract(
self, data: dict
) -> tuple[Optional[str], Optional[ContentExtractor]]:
for format_name in get_extractor_formats():
extractor = self._get_extractor(format_name)
if extractor:
content = extractor.extract_content(data)
if content is not None:
return content, extractor
return None, None
def _calculate_delay(self) -> float:
return self.delay_ms / 1000.0
def _split_content(self, content: str) -> list[str]:
text_length = len(content)
if text_length <= self.chunk_size:
return [content]
return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]
async def smooth(
self, stream_generator: AsyncGenerator[bytes, None]
) -> AsyncGenerator[bytes, None]:
buffer = b""
is_first_content = True
async for chunk in stream_generator:
buffer += chunk
while b"\n\n" in buffer:
event_block, buffer = buffer.split(b"\n\n", 1)
event_str = event_block.decode("utf-8", errors="replace")
lines = event_str.strip().split("\n")
data_str = None
event_type = ""
for line in lines:
line = line.rstrip("\r")
if line.startswith("event: "):
event_type = line[7:].strip()
elif line.startswith("data: "):
data_str = line[6:]
if data_str is None:
yield event_block + b"\n\n"
continue
if data_str.strip() == "[DONE]":
yield event_block + b"\n\n"
continue
try:
data = json.loads(data_str)
except json.JSONDecodeError:
yield event_block + b"\n\n"
continue
content, extractor = self._detect_format_and_extract(data)
if content and len(content) > 1 and extractor:
delay_seconds = self._calculate_delay()
content_chunks = self._split_content(content)
for i, sub_content in enumerate(content_chunks):
is_first = is_first_content and i == 0
sse_chunk = extractor.create_chunk(
data, sub_content, event_type=event_type, is_first=is_first
)
yield sse_chunk
if i < len(content_chunks) - 1:
await asyncio.sleep(delay_seconds)
is_first_content = False
else:
yield event_block + b"\n\n"
if content:
is_first_content = False
if buffer:
yield buffer
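A caller can wrap any upstream SSE byte stream with the standalone helper above. A minimal sketch; upstream_events() and its payload are illustrative only:

```python
import asyncio
import json
from typing import AsyncGenerator


async def upstream_events() -> AsyncGenerator[bytes, None]:
    # One large OpenAI-style delta followed by the terminator.
    payload = {"choices": [{"delta": {"role": "assistant", "content": "Hello, world! " * 5}}]}
    yield f"data: {json.dumps(payload)}\n\n".encode("utf-8")
    yield b"data: [DONE]\n\n"


async def main() -> None:
    # The 70-character delta is re-emitted as four chunks of <= 20 characters,
    # with an 8 ms pause between chunks; [DONE] passes through untouched.
    async for chunk in create_smoothed_stream(upstream_events(), chunk_size=20, delay_ms=8):
        print(chunk.decode("utf-8"), end="")


asyncio.run(main())
```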

View File

@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
Returns:
The normalized response
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return result
# Acting as a relay, pass the response through unchanged; no normalization
return response

View File

@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
Returns:
The normalized response
TODO: Implement response-normalization logic here if needed
"""
# Optional: normalize via response_normalizer
# if (
# self.response_normalizer
# and self.response_normalizer.should_normalize(response)
# ):
# return self.response_normalizer.normalize_gemini_response(
# response_data=response,
# request_id=self.request_id,
# strict=False,
# )
# Acting as a relay, pass the response through unchanged; no normalization
return response

View File

@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
Returns:
The normalized response
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_openai_response(
response_data=response,
request_id=self.request_id,
strict=False,
)
# Acting as a relay, pass the response through unchanged; no normalization
return response

View File

@@ -5,12 +5,55 @@
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from urllib.parse import quote, urlparse
import httpx
from src.core.logger import logger
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
"""
Build the full proxy URL from a proxy config
Args:
proxy_config: Proxy config dict containing url, username, password, enabled
Returns:
The full proxy URL (e.g. socks5://user:pass@host:port),
or None if enabled=False or there is no config
"""
if not proxy_config:
return None
# Check the enabled flag; defaults to True for backward compatibility with old data
if not proxy_config.get("enabled", True):
return None
proxy_url = proxy_config.get("url")
if not proxy_url:
return None
username = proxy_config.get("username")
password = proxy_config.get("password")
# Add credentials whenever a username is present (the password may be empty)
if username:
parsed = urlparse(proxy_url)
# URL-encode the username and password to handle special characters (e.g. @, :, /)
encoded_username = quote(username, safe="")
encoded_password = quote(password, safe="") if password else ""
# Rebuild the proxy URL with credentials
if encoded_password:
auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}"
else:
auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}"
if parsed.path:
auth_proxy += parsed.path
return auth_proxy
return proxy_url
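The expected behavior of build_proxy_url under the rules above, as a sketch; host and credentials are placeholders:

```python
config = {
    "url": "socks5://proxy.example.com:1080",
    "username": "user@corp",   # '@' must be percent-encoded
    "password": "p:ss/w@rd",   # ':' '/' '@' must be percent-encoded
    "enabled": True,
}
print(build_proxy_url(config))
# socks5://user%40corp:p%3Ass%2Fw%40rd@proxy.example.com:1080

# A disabled proxy keeps its configuration but is not used.
print(build_proxy_url({"url": "http://proxy.example.com:8080", "enabled": False}))
# None
```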
class HTTPClientPool:
"""
@@ -121,6 +164,44 @@ class HTTPClientPool:
finally:
await client.aclose()
@classmethod
def create_client_with_proxy(
cls,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: Optional[httpx.Timeout] = None,
**kwargs: Any,
) -> httpx.AsyncClient:
"""
Create an HTTP client with proxy support
Args:
proxy_config: Proxy config dict containing url, username, password
timeout: Timeout configuration
**kwargs: Additional httpx.AsyncClient options
Returns:
A configured httpx.AsyncClient instance
"""
config: Dict[str, Any] = {
"http2": False,
"verify": True,
"follow_redirects": True,
}
if timeout:
config["timeout"] = timeout
else:
config["timeout"] = httpx.Timeout(10.0, read=300.0)
# Apply the proxy configuration
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
config.update(kwargs)
return httpx.AsyncClient(**config)
# Convenience accessor functions
def get_http_client() -> httpx.AsyncClient:

View File

@@ -120,7 +120,7 @@ class RedisClientManager:
if self._circuit_open_until and time.time() < self._circuit_open_until:
remaining = self._circuit_open_until - time.time()
logger.warning(
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
remaining,
self._last_error,
)
@@ -200,7 +200,7 @@ class RedisClientManager:
if self._consecutive_failures >= self._circuit_threshold:
self._circuit_open_until = time.time() + self._circuit_reset_seconds
logger.warning(
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
self._consecutive_failures,
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
if _redis_manager is None:
_redis_manager = RedisClientManager()
# If not yet connected (e.g. degraded at startup, or after close()), try to re-initialize.
# initialize() contains the circuit-breaker logic, avoiding churn from frequent retries.
if _redis_manager.get_client() is None:
await _redis_manager.initialize(require_redis=require_redis)
return _redis_manager.get_client()

View File

@@ -41,8 +41,8 @@ class CacheSize:
class ConcurrencyDefaults:
"""Concurrency-control defaults"""
# Initial adaptive concurrency limit (conservative)
INITIAL_LIMIT = 3
# Initial adaptive concurrency limit (start permissive; back off on 429s)
INITIAL_LIMIT = 50
# Cooldown after a 429 error (minutes) - the concurrency limit is not raised during this window
COOLDOWN_AFTER_429_MINUTES = 5
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
MIN_SAMPLES_FOR_DECISION = 5
# Scale-up step - concurrency added per increase
INCREASE_STEP = 1
INCREASE_STEP = 2
# Scale-down multiplier - reduction ratio on a 429
DECREASE_MULTIPLIER = 0.7
# Scale-down multiplier - ratio applied to the current concurrency on a 429
# 0.85 means dropping to 85% of the concurrency that triggered the 429
DECREASE_MULTIPLIER = 0.85
# Upper bound on the concurrency limit
MAX_CONCURRENT_LIMIT = 100
MAX_CONCURRENT_LIMIT = 200
# Lower bound on the concurrency limit
MIN_CONCURRENT_LIMIT = 1
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
# Minimum requests for probing scale-up - at least this many requests within the probe interval
PROBE_INCREASE_MIN_REQUESTS = 10
# === Cached-user reservation ratio ===
# Fraction of slots reserved for cached users (new users get 1 - this value)
# 0.1 means 10% is reserved for cached users; new users may use 90%
CACHE_RESERVATION_RATIO = 0.1
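A sketch of how these defaults interact; the helpers below are illustrative, not the manager's actual code:

```python
def limit_after_429(current_limit: int) -> int:
    # Drop to 85% of the concurrency that triggered the 429, clamped to the floor.
    return max(
        ConcurrencyDefaults.MIN_CONCURRENT_LIMIT,
        int(current_limit * ConcurrencyDefaults.DECREASE_MULTIPLIER),
    )


def reserved_slots(limit: int) -> int:
    # 10% of slots are held back for cached users; new users get the rest.
    return int(limit * ConcurrencyDefaults.CACHE_RESERVATION_RATIO)


print(limit_after_429(50))   # 42
print(reserved_slots(50))    # 5
```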
class CircuitBreakerDefaults:
"""熔断器配置默认值(滑动窗口 + 半开状态模式)

View File

@@ -105,6 +105,13 @@ class Config:
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
# Trusted-proxy configuration
# TRUSTED_PROXY_COUNT: number of trusted proxy hops (default 1, i.e. trust the nearest proxy)
# Set to 0 to distrust all proxy headers and use the connection IP directly
# When deployed behind Nginx/CloudFlare or similar reverse proxies, set this to the number of proxy hops
# If the service is exposed directly to the public internet, set it to 0 to prevent IP spoofing
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
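An illustrative walk-through of how TRUSTED_PROXY_COUNT is applied to X-Forwarded-For later in PluginMiddleware._get_client_ip; this is a sketch of the same arithmetic, not the middleware code itself:

```python
def client_ip_from_xff(forwarded_for: str, trusted_proxy_count: int) -> str:
    ips = [ip.strip() for ip in forwarded_for.split(",")]
    if len(ips) > trusted_proxy_count:
        # Count the trusted hops from the right; the next entry to the left is the client.
        return ips[-(trusted_proxy_count + 1)]
    return ips[0]


# One trusted proxy appended 10.0.0.1; the real client is 203.0.113.7.
print(client_ip_from_xff("203.0.113.7, 10.0.0.1", trusted_proxy_count=1))  # 203.0.113.7
```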
# Exception-handling configuration
# When True, ProxyException propagates to the route layer so provider_request_headers can be logged
# When False, the global exception handler deals with it centrally
@@ -122,9 +129,9 @@ class Config:
# Concurrency-control configuration
# CONCURRENCY_SLOT_TTL: concurrency-slot TTL (prevents deadlocks)
# CACHE_RESERVATION_RATIO: cached-user reservation ratio (default 30%)
# CACHE_RESERVATION_RATIO: cached-user reservation ratio (default 10%; new users may use 90%)
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# HTTP request timeout configuration (seconds)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))

View File

@@ -46,6 +46,11 @@ class BatchCommitter:
def mark_dirty(self, session: Session):
"""Mark a Session as having pending changes"""
# Request-scoped transactions are committed/rolled back centrally by the middleware; this keeps background tasks from committing mid-request.
if session is None:
return
if session.info.get("managed_by_middleware"):
return
self._pending_sessions.add(session)
async def _batch_commit_loop(self):

View File

@@ -1,168 +0,0 @@
"""
Unified request context
RequestContext spans the entire request lifecycle and carries all request-related information,
ensuring nothing is lost as data passes between layers.
Usage:
1. The pipeline layer creates the RequestContext
2. Each layer reads and updates information through the context
3. The adapter layer uses the context to record Usage
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class RequestContext:
"""
Request context - spans the entire request lifecycle
Design principles:
1. Created at request start with everything known at that point
2. Provider information is filled in progressively as the request executes
3. Used to record Usage when the request ends
"""
# ==================== 请求标识 ====================
request_id: str
# ==================== 认证信息 ====================
user: Any # User model
api_key: Any # ApiKey model
db: Any # Database session
# ==================== 请求信息 ====================
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
model: str # 用户请求的模型名
is_stream: bool = False
# ==================== 原始请求 ====================
original_headers: Dict[str, str] = field(default_factory=dict)
original_body: Dict[str, Any] = field(default_factory=dict)
# ==================== 客户端信息 ====================
client_ip: str = "unknown"
user_agent: str = ""
# ==================== 计时 ====================
start_time: float = field(default_factory=time.time)
# ==================== Provider 信息(请求执行后填充)====================
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
provider_api_key_id: Optional[str] = None
# ==================== 模型映射信息 ====================
resolved_model: Optional[str] = None # 映射后的模型名
original_model: Optional[str] = None # 原始模型名(用于价格计算)
# ==================== 请求/响应头 ====================
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_response_headers: Dict[str, str] = field(default_factory=dict)
# ==================== 追踪信息 ====================
attempt_id: Optional[str] = None
# ==================== 能力需求 ====================
capability_requirements: Dict[str, bool] = field(default_factory=dict)
# 运行时计算的能力需求,来源于:
# 1. 用户 model_capability_settings
# 2. 用户 ApiKey.force_capabilities
# 3. 请求头 X-Require-Capability
# 4. 失败重试时动态添加
@classmethod
def create(
cls,
*,
db: Any,
user: Any,
api_key: Any,
api_format: str,
model: str,
is_stream: bool = False,
original_headers: Optional[Dict[str, str]] = None,
original_body: Optional[Dict[str, Any]] = None,
client_ip: str = "unknown",
user_agent: str = "",
request_id: Optional[str] = None,
) -> "RequestContext":
"""Create a request context"""
return cls(
request_id=request_id or str(uuid.uuid4()),
db=db,
user=user,
api_key=api_key,
api_format=api_format,
model=model,
is_stream=is_stream,
original_headers=original_headers or {},
original_body=original_body or {},
client_ip=client_ip,
user_agent=user_agent,
original_model=model, # Initially the original model equals the requested model
)
def update_provider_info(
self,
*,
provider_name: str,
provider_id: str,
endpoint_id: str,
provider_api_key_id: str,
resolved_model: Optional[str] = None,
) -> None:
"""Update provider info (called after the request executes)"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.provider_api_key_id = provider_api_key_id
if resolved_model:
self.resolved_model = resolved_model
def update_headers(
self,
*,
request_headers: Optional[Dict[str, str]] = None,
response_headers: Optional[Dict[str, str]] = None,
) -> None:
"""Update request/response headers"""
if request_headers:
self.provider_request_headers = request_headers
if response_headers:
self.provider_response_headers = response_headers
@property
def elapsed_ms(self) -> int:
"""Elapsed time in milliseconds"""
return int((time.time() - self.start_time) * 1000)
@property
def effective_model(self) -> str:
"""Effective model name (prefer the mapped name)"""
return self.resolved_model or self.model
@property
def billing_model(self) -> str:
"""Billing model name (prefer the original model)"""
return self.original_model or self.model
def to_metadata_dict(self) -> Dict[str, Any]:
"""Convert to a metadata dict (for Usage records)"""
return {
"api_format": self.api_format,
"provider": self.provider_name or "unknown",
"model": self.effective_model,
"original_model": self.billing_model,
"provider_id": self.provider_id,
"provider_endpoint_id": self.endpoint_id,
"provider_api_key_id": self.provider_api_key_id,
"provider_request_headers": self.provider_request_headers,
"provider_response_headers": self.provider_response_headers,
"attempt_id": self.attempt_id,
}

View File

@@ -10,8 +10,8 @@ class APIFormat(Enum):
"""API format enum - determines how requests and responses are handled"""
CLAUDE = "CLAUDE" # Claude API format
OPENAI = "OPENAI" # OpenAI API format
CLAUDE_CLI = "CLAUDE_CLI" # Claude CLI API format (uses authorization: Bearer)
OPENAI = "OPENAI" # OpenAI API format
OPENAI_CLI = "OPENAI_CLI" # OpenAI CLI/Responses API format (for clients such as Claude Code)
GEMINI = "GEMINI" # Google Gemini API format
GEMINI_CLI = "GEMINI_CLI" # Gemini CLI API format

View File

@@ -188,12 +188,16 @@ class ProviderNotAvailableException(ProviderException):
message: str,
provider_name: Optional[str] = None,
request_metadata: Optional[Any] = None,
upstream_status: Optional[int] = None,
upstream_response: Optional[str] = None,
):
super().__init__(
message=message,
provider_name=provider_name,
request_metadata=request_metadata,
)
self.upstream_status = upstream_status
self.upstream_response = upstream_response
class ProviderTimeoutException(ProviderException):
@@ -442,6 +446,36 @@ class EmbeddedErrorException(ProviderException):
self.error_status = error_status
class ProviderCompatibilityException(ProviderException):
"""Provider compatibility error exception - should trigger failover
Raised when a provider rejects a request because it does not support certain parameters or features.
Such errors are not the user's fault; another provider may well succeed, so failover should be triggered.
Common cases:
- Unsupported parameter
- Unsupported model
- Unsupported feature
"""
def __init__(
self,
message: str,
provider_name: Optional[str] = None,
status_code: int = 400,
upstream_error: Optional[str] = None,
request_metadata: Optional[Any] = None,
):
self.upstream_error = upstream_error
super().__init__(
message=message,
provider_name=provider_name,
request_metadata=request_metadata,
)
# Override the status code to 400 to stay consistent with upstream
self.status_code = status_code
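A minimal sketch of the failover loop this exception enables; try_provider() and the provider list are hypothetical stand-ins for the real scheduler plumbing:

```python
from typing import Optional


async def call_with_failover(providers: list[str], request: dict) -> dict:
    last_error: Optional[Exception] = None
    for name in providers:
        try:
            return await try_provider(name, request)  # hypothetical helper
        except ProviderCompatibilityException as exc:
            # Not the user's fault: this provider rejected a parameter or feature,
            # so move on to the next candidate instead of failing the request.
            last_error = exc
    raise last_error or RuntimeError("no providers available")
```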
class UpstreamClientException(ProxyException):
"""Upstream client-error exception - HTTP 4xx, must not be retried

View File

@@ -9,7 +9,7 @@
Output policy:
- Console: dev = DEBUG, production = INFO (controlled via LOG_LEVEL)
- File: always DEBUG level, 30-day retention, daily rotation
- File: always DEBUG level, 30-day retention, size-based rotation (100 MB)
Usage:
from src.core.logger import logger
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
if IS_DOCKER:
# Production: disable backtrace and diagnose to reduce log noise
logger.add(
sys.stdout,
format=CONSOLE_FORMAT_PROD,
level=LOG_LEVEL,
filter=_log_filter, # type: ignore[arg-type]
colorize=False,
backtrace=False,
diagnose=False,
)
else:
logger.add(
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
log_dir = PROJECT_ROOT / "logs"
log_dir.mkdir(exist_ok=True)
# Shared file-log configuration
file_log_config = {
"format": FILE_FORMAT,
"filter": _log_filter,
"rotation": "100 MB",
"retention": "30 days",
"compression": "gz",
"enqueue": True,
"encoding": "utf-8",
"catch": True,
}
# Disable verbose stack traces in production
if IS_DOCKER:
file_log_config["backtrace"] = False
file_log_config["diagnose"] = False
# Main log file - all levels
logger.add(
log_dir / "app.log",
format=FILE_FORMAT,
level="DEBUG",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**file_log_config, # type: ignore[arg-type]
)
# Error log file - ERROR and above only
error_log_config = file_log_config.copy()
error_log_config["rotation"] = "50 MB"
logger.add(
log_dir / "error.log",
format=FILE_FORMAT,
level="ERROR",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**error_log_config, # type: ignore[arg-type]
)
# ============================================================================

View File

@@ -5,6 +5,7 @@
import time
from typing import AsyncGenerator, Generator, Optional
from starlette.requests import Request
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
@@ -150,9 +151,22 @@ def _log_pool_capacity():
theoretical = config.db_pool_size + config.db_max_overflow
workers = max(1, config.worker_processes)
total_estimated = theoretical * workers
logger.info("数据库连接池配置")
if total_estimated > config.db_pool_warn_threshold:
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
safe_limit = config.pg_max_connections - config.pg_reserved_connections
logger.info(
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
config.db_pool_size,
config.db_max_overflow,
workers,
total_estimated,
safe_limit,
)
if total_estimated > safe_limit:
logger.warning(
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved)"
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
total_estimated,
safe_limit,
)
def _ensure_async_engine() -> AsyncEngine:
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
# Create the async engine
_async_engine = create_async_engine(
ASYNC_DATABASE_URL,
poolclass=QueuePool, # Use a queue-based connection pool
# AsyncEngine cannot use QueuePool; it defaults to AsyncAdaptedQueuePool
pool_size=config.db_pool_size,
max_overflow=config.db_max_overflow,
pool_timeout=config.db_pool_timeout,
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""Get an async database session"""
"""Get an async database session
.. deprecated::
This method is deprecated; the project uses synchronous Sessions throughout.
It may be removed in a future release. Use get_db() instead.
"""
import warnings
warnings.warn(
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
DeprecationWarning,
stacklevel=2,
)
# Ensure the async engine is initialized
_ensure_async_engine()
@@ -220,16 +245,73 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
def get_db() -> Generator[Session, None, None]:
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
"""Get a database session
Note: transaction management is controlled explicitly by the business layer (manual commit/rollback);
this function only creates and closes the session and never commits automatically
Transaction strategy
====================
This project uses a **hybrid transaction management** strategy:
1. **LLM request path**
- Transactions are managed centrally by PluginMiddleware
- The service layer uses db.flush() to make changes visible without committing
- The middleware commits or rolls back once at the end of the request
- Exception: UsageService.record_usage() commits explicitly, because usage records must be persisted immediately
2. **Admin API**
- The route layer calls db.commit() explicitly
- After committing, it sets request.state.tx_committed_by_route = True
- The middleware sees this flag, skips the commit, and only closes the session
3. **Background tasks / schedulers**
- Use an independent Session via create_session() or next(get_db())
- Manage their own transaction lifecycle
Usage
=====
- FastAPI requests: inject via Depends(get_db); supports reusing the middleware-managed session
- Outside a request context: call get_db() directly; it falls back to independent-session mode
Route-layer commit example
==========================
```python
@router.post("/example")
async def example(request: Request, db: Session = Depends(get_db)):
    # ... business logic ...
    db.commit()
    request.state.tx_committed_by_route = True  # Tell the middleware the commit is done
    return {"message": "success"}
```
Notes
=====
- This function never commits automatically
- It rolls back automatically on exceptions
- In middleware-managed mode, closing the session is the middleware's responsibility
"""
# FastAPI request context: prefer reusing the middleware-bound request.state.db
if request is not None:
existing_db = getattr(getattr(request, "state", None), "db", None)
if isinstance(existing_db, Session):
yield existing_db
return
# Ensure the engine is initialized
_ensure_engine()
db = _SessionLocal()
# If the middleware has declared that it manages the session lifecycle, bind the session to request.state
# and let the middleware handle commit/rollback/close; don't close here, or streaming responses could lose their session early
managed_by_middleware = bool(
request is not None
and hasattr(request, "state")
and getattr(request.state, "db_managed_by_middleware", False)
)
if managed_by_middleware:
request.state.db = db
db.info["managed_by_middleware"] = True
try:
yield db
# No automatic commit; business code manages transactions explicitly
@@ -241,12 +323,13 @@ def get_db() -> Generator[Session, None, None]:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
raise
finally:
try:
db.close() # Ensure the connection is returned to the pool
except Exception as close_error:
# Log close errors (e.g. IllegalStateChangeError);
# the connection pool handles reclaiming the connection
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
if not managed_by_middleware:
try:
db.close() # Ensure the connection is returned to the pool
except Exception as close_error:
# Log close errors (e.g. IllegalStateChangeError);
# the connection pool handles reclaiming the connection
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def create_session() -> Session:
@@ -336,7 +419,7 @@ def init_admin_user(db: Session):
admin.set_password(config.admin_password)
db.add(admin)
db.commit() # Refresh to obtain the ID, but do not commit
db.flush() # Assign the ID without committing the transaction (the outer init_db commits once)
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
except Exception as e:

View File

@@ -3,15 +3,11 @@
Built with a modular architecture
"""
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from src.api.admin import router as admin_router
from src.api.announcements import router as announcement_router
@@ -39,20 +35,18 @@ async def initialize_providers():
"""Initialize providers from the database (used for logging only)"""
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.database import get_db
from src.database.database import create_session
from src.models.database import Provider
try:
# Create a database session
db_gen = get_db()
db: Session = next(db_gen)
db: Session = create_session()
try:
# Load all active providers from the database
providers = (
db.query(Provider)
.filter(Provider.is_active == True)
.filter(Provider.is_active.is_(True))
.order_by(Provider.provider_priority.asc())
.all()
)
@@ -75,7 +69,7 @@ async def initialize_providers():
finally:
db.close()
except Exception as e:
except Exception:
logger.exception("从数据库初始化提供商失败")
@@ -125,6 +119,7 @@ async def lifespan(app: FastAPI):
logger.info("初始化全局Redis客户端...")
from src.clients.redis_client import get_redis_client
redis_client = None
try:
redis_client = await get_redis_client(require_redis=config.require_redis)
if redis_client:
@@ -136,6 +131,7 @@ async def lifespan(app: FastAPI):
logger.exception("[ERROR] Redis连接失败应用启动中止")
raise
logger.warning(f"Redis连接失败但配置允许降级将继续使用内存模式: {e}")
redis_client = None
# Initialize the concurrency manager (it uses Redis internally)
logger.info("初始化并发管理器...")
app.include_router(dashboard_router) # Dashboard endpoints
app.include_router(public_router) # Public API endpoints (users can browse providers and models)
app.include_router(monitoring_router) # Monitoring endpoints
# Static file serving (frontend build output)
# Check whether the frontend build directory exists
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
if frontend_dist.exists():
# Mount the static assets directory
app.mount("/assets", StaticFiles(directory=str(frontend_dist / "assets")), name="assets")
# SPA catch-all route - must be registered last
@app.get("/{full_path:path}")
async def serve_spa(request: Request, full_path: str):
"""
Handle all unmatched GET requests by returning index.html for the frontend router
Applies only to non-API paths
"""
# Don't handle API paths
if full_path.startswith("api/") or full_path.startswith("v1/"):
raise HTTPException(status_code=404, detail="Not Found")
# Return index.html and let the frontend router handle it
index_file = frontend_dist / "index.html"
if index_file.exists():
return FileResponse(str(index_file))
else:
raise HTTPException(status_code=404, detail="Frontend not built")
else:
logger.warning("前端构建目录不存在,前端路由将无法使用")
def main():

View File

@@ -1,38 +1,43 @@
"""
Unified plugin middleware
Unified plugin middleware (pure ASGI implementation)
Coordinates the invocation of all plugins
Note: implemented as pure ASGI middleware rather than BaseHTTPMiddleware,
to avoid Starlette's known streaming-response compatibility issues.
"""
import hashlib
import time
from typing import Any, Awaitable, Callable, Optional
from typing import Optional
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from src.config import config
from src.core.logger import logger
from src.database import get_db
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
class PluginMiddleware(BaseHTTPMiddleware):
class PluginMiddleware:
"""
Unified plugin-invocation middleware
Unified plugin-invocation middleware (pure ASGI implementation)
Responsibilities:
- Performance monitoring
- Rate limiting (optional)
- Database session lifecycle management
Note: authentication is declared explicitly by each route via Depends(); it is not handled in this middleware
Why pure ASGI instead of BaseHTTPMiddleware:
- BaseHTTPMiddleware buffers the entire response body, which is hostile to streaming
- BaseHTTPMiddleware has known compatibility issues with StreamingResponse
- Pure ASGI passes streaming responses straight through with no extra overhead
"""
def __init__(self, app: Any) -> None:
super().__init__(app)
def __init__(self, app: ASGIApp) -> None:
self.app = app
self.plugin_manager = get_plugin_manager()
# Read rate-limit values from config
@@ -62,175 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
"/v1/completions",
]
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
) -> StarletteResponse:
"""Handle the request and invoke the matching plugins"""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI entry point"""
if scope["type"] != "http":
# Non-HTTP requests (e.g. WebSocket) pass straight through
await self.app(scope, receive, send)
return
# Build a Request object so existing logic can be reused
request = Request(scope, receive, send)
# Record the request start time
start_time = time.time()
# Set request.state attributes
# Note: Starlette's Request object always has a state attribute (a State instance)
request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time
# Flag: if a session is created via Depends(get_db) during this request, this middleware manages its lifecycle
request.state.db_managed_by_middleware = True
# Get the FastAPI app instance from request.app (rather than from __init__'s app argument)
# so we can reach the real FastAPI instance and its dependency_overrides
db_func = get_db
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
if get_db in request.app.dependency_overrides:
db_func = request.app.dependency_overrides[get_db]
logger.debug("Using overridden get_db from app.dependency_overrides")
# 1. Rate-limit check (before calling downstream)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# Rate limit triggered; return 429
await self._send_rate_limit_response(send, rate_limit_result)
return
# 创建数据库会话供需要的插件或后续处理使
db_gen = db_func()
db = None
response = None
exception_to_raise = None
# 2. 预处理插件调
await self._call_pre_request_plugins(request)
# Used to capture the response status code
response_status_code: int = 0
async def send_wrapper(message: Message) -> None:
nonlocal response_status_code
if message["type"] == "http.response.start":
response_status_code = message.get("status", 0)
await send(message)
# 3. Call the downstream app
exception_occurred: Optional[Exception] = None
try:
# 获取数据库会话
db = next(db_gen)
request.state.db = db
# 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
headers = rate_limit_result.headers or {}
raise HTTPException(
status_code=429,
detail=rate_limit_result.message or "Rate limit exceeded",
headers=headers,
)
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 处理请求
response = await call_next(request)
# 3. Commit the critical database transaction (before returning the response)
# This ensures usage records, quota deductions, and other critical data persist before the response goes out
try:
db.commit()
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
db.rollback()
# 返回 500 错误,因为数据可能不一致
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "database_error",
"message": "数据保存失败,请重试",
},
},
)
# 跳过后处理插件,直接返回错误响应
return response
# 4. Post-request plugin calls (monitoring etc.; non-critical)
# Failures here must not affect the user's response
await self._call_post_request_plugins(request, response, start_time)
# Note: rate-limit response headers are not added here, because in BaseHTTPMiddleware
# modifying headers after the response has been returned causes Content-Length mismatch errors.
# The rate-limit headers are already included with the 429 error (see the HTTPException above).
except RuntimeError as e:
if str(e) == "No response returned.":
if db:
db.rollback()
logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time)
if db:
try:
db.commit()
except Exception:
pass
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "internal_error",
"message": "Internal server error: downstream handler returned no response.",
},
},
)
else:
exception_to_raise = e
await self.app(scope, receive, send_wrapper)
except Exception as e:
# 回滚数据库事务
if db:
db.rollback()
exception_occurred = e
# 错误处理插件调用
await self._call_error_plugins(request, e, start_time)
raise
finally:
# 4. Database session cleanup (success or failure)
await self._cleanup_db_session(request, exception_occurred)
# 尝试提交错误日志
if db:
# 5. Post-request plugin calls (only on success)
if not exception_occurred and response_status_code > 0:
await self._call_post_request_plugins(request, response_status_code, start_time)
async def _send_rate_limit_response(
self, send: Send, result: RateLimitResult
) -> None:
"""Send a 429 rate-limit response"""
import json
body = json.dumps({
"type": "error",
"error": {
"type": "rate_limit_error",
"message": result.message or "Rate limit exceeded",
},
}).encode("utf-8")
headers = [(b"content-type", b"application/json")]
if result.headers:
for key, value in result.headers.items():
headers.append((key.lower().encode(), str(value).encode()))
await send({
"type": "http.response.start",
"status": 429,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": body,
})
async def _cleanup_db_session(
self, request: Request, exception: Optional[Exception]
) -> None:
"""Clean up the database session
Transaction policy:
- If request.state.tx_committed_by_route is True, the route already committed and the middleware only closes
- Otherwise the middleware commits or rolls back centrally
This avoids double commits while remaining backward compatible.
"""
from sqlalchemy.orm import Session
db = getattr(request.state, "db", None)
if not isinstance(db, Session):
return
# Check whether the route layer already committed
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
try:
if exception is not None:
# An exception occurred; roll back (regardless of who owns the commit)
try:
db.rollback()
except Exception as rollback_error:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
elif not tx_committed_by_route:
# Completed normally and the route didn't commit; the middleware commits
try:
db.commit()
except:
pass
exception_to_raise = e
finally:
# 确保数据库会话被正确关闭
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
if db is not None:
try:
# 检查会话是否可以安全地进行回滚
# 只有当没有进行中的事务操作时才尝试回滚
if db.is_active and not db.get_transaction().is_active:
# 事务不在活跃状态,可以安全回滚
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
db.rollback()
except Exception:
pass
elif db.is_active:
# 事务在活跃状态,尝试回滚
try:
db.rollback()
except Exception as rollback_error:
# 回滚失败(可能是 commit 正在进行中),忽略错误
logger.debug(f"Rollback skipped: {rollback_error}")
except Exception:
# 检查状态时出错,忽略
pass
# 通过触发生成器的 finally 块来关闭会话(标准模式)
# 这会调用 get_db() 的 finally 块,执行 db.close()
# 如果 tx_committed_by_route 为 True跳过 commit路由已提交
finally:
# 关闭会话,归还连接到连接池
try:
next(db_gen, None)
except StopIteration:
# 正常情况:生成器已耗尽
pass
except Exception as cleanup_error:
# 忽略 IllegalStateChangeError 等清理错误
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
logger.warning(f"Database cleanup warning: {cleanup_error}")
# 在 finally 块之后处理异常和响应
if exception_to_raise:
raise exception_to_raise
return response
db.close()
except Exception as close_error:
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def _get_client_ip(self, request: Request) -> str:
"""
Get the client IP address, honoring proxy headers
Note: this method trusts the X-Forwarded-For and X-Real-IP headers,
which is only safe when the service sits behind trusted proxies (e.g. Nginx, CloudFlare).
If the service is exposed directly to the public internet, attackers can forge these headers to bypass rate limiting.
"""
# Read the trusted proxy hop count from config (default 1, i.e. trust the nearest proxy)
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
# Prefer the real IP from proxy headers
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For 可能包含多个 IP取第一个
return forwarded_for.split(",")[0].strip()
# X-Forwarded-For format: "client, proxy1, proxy2"
# Count trusted_proxy_count entries from the right; the client is the first entry to their left
ips = [ip.strip() for ip in forwarded_for.split(",")]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
real_ip = request.headers.get("x-real-ip")
if real_ip:
@@ -250,7 +239,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
return False
async def _get_rate_limit_key_and_config(
self, request: Request, db: Session
self, request: Request
) -> tuple[Optional[str], Optional[int]]:
"""
Get the rate-limit key and configuration
@@ -272,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
auth_header = request.headers.get("authorization", "")
api_key = request.headers.get("x-api-key", "")
if auth_header.startswith("Bearer "):
if auth_header.lower().startswith("bearer "):
api_key = auth_header[7:]
if api_key:
# Use a hash of the API key as the limit key, so the full key never leaks into logs
import hashlib
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
key = f"llm_api_key:{key_hash}"
request.state.rate_limit_key_type = "api_key"
@@ -318,14 +305,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
# No rate-limit plugin present; allow the request through
return None
# 获取数据库会话
db = getattr(request.state, "db", None)
if not db:
logger.warning("速率限制检查:无法获取数据库会话")
return None
# 获取速率限制的key和配置从数据库
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
# Get the rate-limit key and configuration
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
if not key:
# Endpoint doesn't need rate limiting (e.g. unclassified path); skip silently
return None
@@ -336,7 +317,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
key=key,
endpoint=request.url.path,
method=request.method,
rate_limit=rate_limit_value, # 传入数据库配置的限制值
rate_limit=rate_limit_value, # 传入配置的限制值
)
# 类型检查确保返回的是RateLimitResult类型
if isinstance(result, RateLimitResult):
@@ -349,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
)
else:
# Rate limit triggered; log it
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
logger.warning(
"速率限制触发: {}",
getattr(request.state, "rate_limit_key_type", "unknown"),
)
return result
return None
except Exception as e:
@@ -362,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
pass
async def _call_post_request_plugins(
self, request: Request, response: StarletteResponse, start_time: float
self, request: Request, status_code: int, start_time: float
) -> None:
"""Invoke post-request plugins"""
@@ -375,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
monitor_labels = {
"method": request.method,
"endpoint": request.url.path,
"status": str(response.status_code),
"status_class": f"{response.status_code // 100}xx",
"status": str(status_code),
"status_class": f"{status_code // 100}xx",
}
# 记录请求计数
@@ -398,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
self, request: Request, error: Exception, start_time: float
) -> None:
"""Invoke error-handling plugins"""
from fastapi import HTTPException
duration = time.time() - start_time
@@ -410,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
error=error,
context={
"endpoint": f"{request.method} {request.url.path}",
"request_id": request.state.request_id,
"request_id": getattr(request.state, "request_id", ""),
"duration": duration,
},
)

View File

@@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from src.core.enums import APIFormat, ProviderBillingType
class ProxyConfig(BaseModel):
"""Proxy configuration"""
url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
username: Optional[str] = Field(None, max_length=255, description="代理用户名")
password: Optional[str] = Field(None, max_length=500, description="代理密码")
enabled: bool = Field(True, description="是否启用代理false 时保留配置但不使用)")
@field_validator("url")
@classmethod
def validate_proxy_url(cls, v: str) -> str:
"""Validate the proxy URL format"""
from urllib.parse import urlparse
v = v.strip()
# Reject disallowed characters (prevent injection)
if "\n" in v or "\r" in v:
raise ValueError("代理 URL 包含非法字符")
# Validate the scheme (SOCKS4 is not supported)
if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
# Validate the URL structure
parsed = urlparse(v)
if not parsed.netloc:
raise ValueError("代理 URL 必须包含有效的 host")
# Forbid credentials embedded in the URL; require the separate auth fields
if parsed.username or parsed.password:
raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
return v
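How the validator behaves in practice, as a sketch; this assumes pydantic v2's ValidationError:

```python
from pydantic import ValidationError

ok = ProxyConfig(url="socks5://proxy.example.com:1080", username="u", password="p")

try:
    ProxyConfig(url="socks5://user:pass@proxy.example.com:1080")
except ValidationError:
    pass  # credentials embedded in the URL are rejected; use the separate fields

try:
    ProxyConfig(url="socks4://proxy.example.com:1080")
except ValidationError:
    pass  # only http://, https://, and socks5:// are accepted
```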
class CreateProviderRequest(BaseModel):
"""Create-Provider request"""
@@ -107,20 +143,6 @@ class CreateProviderRequest(BaseModel):
if not re.match(r"^https?://", v, re.IGNORECASE):
v = f"https://{v}"
# Prevent SSRF: forbid internal network addresses
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v
@field_validator("billing_type")
@@ -179,6 +201,7 @@ class CreateEndpointRequest(BaseModel):
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@field_validator("name")
@classmethod
@@ -195,19 +218,6 @@ class CreateEndpointRequest(BaseModel):
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# Prevent SSRF
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # Strip the trailing slash
@field_validator("api_format")
@@ -247,6 +257,7 @@ class UpdateEndpointRequest(BaseModel):
rpm_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
# Reuse the validators
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)

View File

@@ -6,7 +6,7 @@ import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from ..core.enums import UserRole
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
active_models_count: int = 0
api_keys_count: int = 0
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== Model management ==========
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
global_model_name: Optional[str] = None
global_model_display_name: Optional[str] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ModelDetailResponse(BaseModel):
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== System settings ==========

View File

@@ -5,7 +5,7 @@ API models for provider API keys
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class ProviderAPIKeyBase(BaseModel):
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ProviderAPIKeyStats(BaseModel):

View File

@@ -27,8 +27,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship
from ..config import config
from ..core.enums import ProviderBillingType, UserRole
@@ -539,6 +538,9 @@ class ProviderEndpoint(Base):
# Extra configuration
config = Column(JSON, nullable=True) # Endpoint-specific config (discouraged; prefer dedicated fields)
# Proxy configuration
proxy = Column(JSONB, nullable=True) # Proxy config: {url, username, password}
# Timestamps
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
@@ -813,7 +815,9 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False)
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
def select_provider_model_name(
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
) -> str:
"""Select the provider model name to use, by priority
If provider_model_aliases is configured, select by priority (lower numbers win)
@@ -822,6 +826,7 @@ class Model(Base):
Args:
affinity_key: Affinity key used for hash spreading (e.g. a hash of the user's API key), so the same user consistently selects the same alias
api_format: API format of the current request (e.g. CLAUDE, OPENAI), used to filter applicable aliases
"""
import hashlib
@@ -840,6 +845,13 @@ class Model(Base):
if not isinstance(name, str) or not name.strip():
continue
# Check the api_formats scope (when configured and the request has an api_format)
alias_api_formats = raw.get("api_formats")
if api_format and alias_api_formats:
# When a scope is configured, the alias applies only if the format matches
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
continue
raw_priority = raw.get("priority", 1)
try:
priority = int(raw_priority)
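An illustrative provider_model_aliases value showing the api_formats scoping; the model names are placeholders. Lower priority numbers win, and a scoped alias only applies when the request format matches:

```python
provider_model_aliases = [
    {"name": "claude-sonnet-4", "priority": 1, "api_formats": ["CLAUDE", "CLAUDE_CLI"]},
    {"name": "claude-sonnet-4-openai", "priority": 2, "api_formats": ["OPENAI"]},
    {"name": "claude-sonnet-4-fallback", "priority": 3},  # no scope: applies to all formats
]
```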

View File

@@ -6,7 +6,9 @@ import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from src.models.admin_requests import ProxyConfig
# ========== ProviderEndpoint CRUD ==========
@@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel):
# Extra configuration
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON")
# Proxy configuration
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@field_validator("api_format")
@classmethod
def validate_api_format(cls, v: str) -> str:
@@ -45,24 +50,9 @@ class ProviderEndpointCreate(BaseModel):
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: str) -> str:
"""Validate the API URL (SSRF protection)"""
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# Prevent SSRF: forbid internal network addresses
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # Strip the trailing slash
@@ -79,31 +69,18 @@ class ProviderEndpointUpdate(BaseModel):
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
is_active: Optional[bool] = Field(default=None, description="是否启用")
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate the API URL (SSRF protection)"""
"""Validate the API URL"""
if v is None:
return v
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# Prevent SSRF: forbid internal network addresses
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # Strip the trailing slash
@@ -133,6 +110,9 @@ class ProviderEndpointResponse(BaseModel):
# 额外配置
config: Optional[Dict[str, Any]] = None
# Proxy configuration (password masked in responses)
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)")
# Stats (aggregated from keys)
total_keys: int = Field(default=0, description="总 Key 数量")
active_keys: int = Field(default=0, description="活跃 Key 数量")
@@ -141,8 +121,7 @@ class ProviderEndpointResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== ProviderAPIKey (new architecture) ==========
@@ -384,8 +363,7 @@ class EndpointAPIKeyResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== Health monitoring ==========
@@ -535,8 +513,7 @@ class ProviderWithEndpointsSummary(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== Health-monitoring visualization models ==========

View File

@@ -5,7 +5,7 @@ Pydantic data models (phase 1: unified model management)
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
# ========== Tiered-pricing models ==========
@@ -238,8 +238,8 @@ class GlobalModelResponse(BaseModel):
# Per-request billing configuration
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
# Tiered-pricing configuration
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置"
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
default=None, description="阶梯计费配置"
)
# Key capability configuration - list of capabilities the model supports
supported_capabilities: Optional[List[str]] = Field(
@@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel):
created_at: datetime
updated_at: Optional[datetime]
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class GlobalModelWithStats(GlobalModelResponse):

View File

@@ -3,6 +3,7 @@ JWT authentication plugin
Supports JWT Bearer token authentication
"""
import hashlib
from typing import Optional
from fastapi import Request
@@ -46,12 +47,12 @@ class JwtAuthPlugin(AuthPlugin):
logger.debug("未找到JWT token")
return None
# Log details of the authentication attempt
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
try:
# 验证JWT token
payload = AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
logger.debug(f"JWT token验证成功, payload: {payload}")
# Extract user info from the payload

View File

@@ -63,14 +63,16 @@ class JWTBlacklistService:
if ttl_seconds <= 0:
# The token has already expired; no need to blacklist it
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
return True
# Store in Redis with a TTL equal to the token's remaining lifetime
# The value holds the reason string
await redis_client.setex(redis_key, ttl_seconds, reason)
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
return True
except Exception as e:
@@ -109,7 +111,8 @@ class JWTBlacklistService:
if exists:
# Fetch the blacklist reason (optional)
reason = await redis_client.get(redis_key)
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
return True
return False
@@ -148,9 +151,11 @@ class JWTBlacklistService:
deleted = await redis_client.delete(redis_key)
if deleted:
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
else:
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
return bool(deleted)

View File

@@ -3,6 +3,7 @@
"""
import os
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
@@ -93,8 +94,8 @@ class AuthService:
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""Authenticate a user login"""
# Query the user via the cache
user = await UserCacheService.get_user_by_email(db, email)
# Login verification must read the password hash; cached objects omit password_hash and cannot be used
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
@@ -109,13 +110,10 @@ class AuthService:
return None
# Update the last-login time
# Must re-fetch from the database to update it (the cached object is detached)
db_user = db.query(User).filter(User.id == user.id).first()
if db_user:
db_user.last_login_at = datetime.now(timezone.utc)
db.commit() # Commit immediately to release database locks
# Clear the cache, since the user info changed
await UserCacheService.invalidate_user_cache(user.id, user.email)
user.last_login_at = datetime.now(timezone.utc)
db.commit() # Commit immediately to release database locks
# Clear the cache, since the user info changed
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@@ -172,7 +170,8 @@ class AuthService:
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # Commit immediately to release database locks and avoid blocking subsequent requests
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
return user, key_record
@staticmethod
@@ -198,7 +197,10 @@ class AuthService:
if user.role == UserRole.ADMIN:
return True
if user.role.value >= required_role.value:
# Avoid string comparisons that break permission checks (e.g. 'user' >= 'admin')
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
# Unknown user roles default to -1 (deny); unknown required roles default to 999 (deny)
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
return True
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
@@ -230,7 +232,7 @@ class AuthService:
)
if success:
user_id = payload.get("sub")
user_id = payload.get("user_id")
logger.info(f"用户登出成功: user_id={user_id}")
return success

View File

@@ -59,7 +59,6 @@ from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_reservation import (
AdaptiveReservationManager,
ReservationResult,
get_adaptive_reservation_manager,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
@@ -112,8 +111,6 @@ class CacheAwareScheduler:
- 健康度监控
"""
# Static constant kept as a default (the actual value is computed dynamically by AdaptiveReservationManager)
CACHE_RESERVATION_RATIO = 0.3
# Priority-mode constants
PRIORITY_MODE_PROVIDER = "provider" # Provider-first mode
PRIORITY_MODE_GLOBAL_KEY = "global_key" # Global-key-first mode
@@ -121,8 +118,17 @@ class CacheAwareScheduler:
PRIORITY_MODE_PROVIDER,
PRIORITY_MODE_GLOBAL_KEY,
}
# Scheduling-mode constants
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # Fixed-order mode
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # Cache-affinity mode
ALLOWED_SCHEDULING_MODES = {
SCHEDULING_MODE_FIXED_ORDER,
SCHEDULING_MODE_CACHE_AFFINITY,
}
def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
def __init__(
self, redis_client=None, priority_mode: Optional[str] = None, scheduling_mode: Optional[str] = None
):
"""
初始化调度器
@@ -132,12 +138,16 @@ class CacheAwareScheduler:
Args:
redis_client: Redis客户端(可选)
priority_mode: 候选排序策略(provider | global_key)
scheduling_mode: 调度模式(fixed_order | cache_affinity)
"""
self.redis = redis_client
self.priority_mode = self._normalize_priority_mode(
priority_mode or self.PRIORITY_MODE_PROVIDER
)
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}")
self.scheduling_mode = self._normalize_scheduling_mode(
scheduling_mode or self.SCHEDULING_MODE_CACHE_AFFINITY
)
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}, 调度模式: {self.scheduling_mode}")
# 初始化子组件(将在第一次使用时异步初始化)
self._affinity_manager: Optional[CacheAffinityManager] = None
@@ -673,14 +683,19 @@ class CacheAwareScheduler:
f"(api_format={target_format.value}, model={model_name})"
)
# 4. 应用缓存亲和性排序(使用 global_model_id 作为模型标识)
if affinity_key and candidates:
candidates = await self._apply_cache_affinity(
candidates=candidates,
affinity_key=affinity_key,
api_format=target_format,
global_model_id=global_model_id,
)
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
if affinity_key and candidates:
candidates = await self._apply_cache_affinity(
candidates=candidates,
affinity_key=affinity_key,
api_format=target_format,
global_model_id=global_model_id,
)
else:
# 固定顺序模式:标记所有候选为非缓存
for candidate in candidates:
candidate.is_cached = False
return candidates, global_model_id
@@ -1060,6 +1075,22 @@ class CacheAwareScheduler:
self.priority_mode = normalized
logger.debug(f"[CacheAwareScheduler] 切换优先级模式为: {self.priority_mode}")
def _normalize_scheduling_mode(self, mode: Optional[str]) -> str:
normalized = (mode or "").strip().lower()
if normalized not in self.ALLOWED_SCHEDULING_MODES:
if normalized:
logger.warning(f"[CacheAwareScheduler] 无效的调度模式 '{mode}',回退为 cache_affinity")
return self.SCHEDULING_MODE_CACHE_AFFINITY
return normalized
def set_scheduling_mode(self, mode: Optional[str]) -> None:
"""运行时更新调度模式"""
normalized = self._normalize_scheduling_mode(mode)
if normalized == self.scheduling_mode:
return
self.scheduling_mode = normalized
logger.debug(f"[CacheAwareScheduler] 切换调度模式为: {self.scheduling_mode}")
def _apply_priority_mode_sort(
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
) -> List[ProviderCandidate]:
@@ -1286,7 +1317,6 @@ class CacheAwareScheduler:
return {
"scheduler": "cache_aware",
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
"dynamic_reservation": {
"enabled": True,
"config": reservation_stats["config"],
@@ -1307,6 +1337,7 @@ _scheduler: Optional[CacheAwareScheduler] = None
async def get_cache_aware_scheduler(
redis_client=None,
priority_mode: Optional[str] = None,
scheduling_mode: Optional[str] = None,
) -> CacheAwareScheduler:
"""
获取全局CacheAwareScheduler实例
@@ -1317,6 +1348,7 @@ async def get_cache_aware_scheduler(
Args:
redis_client: Redis客户端(可选)
priority_mode: 外部覆盖的优先级模式(provider | global_key)
scheduling_mode: 外部覆盖的调度模式(fixed_order | cache_affinity)
Returns:
CacheAwareScheduler实例
@@ -1324,8 +1356,13 @@ async def get_cache_aware_scheduler(
global _scheduler
if _scheduler is None:
_scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
elif priority_mode:
_scheduler.set_priority_mode(priority_mode)
_scheduler = CacheAwareScheduler(
redis_client, priority_mode=priority_mode, scheduling_mode=scheduling_mode
)
else:
if priority_mode:
_scheduler.set_priority_mode(priority_mode)
if scheduling_mode:
_scheduler.set_scheduling_mode(scheduling_mode)
return _scheduler
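
A sketch of how the revised factory behaves (names from the hunk above; the Redis client argument is a placeholder):

async def demo(redis_client) -> None:
    # The first call constructs the global singleton with an explicit mode.
    s = await get_cache_aware_scheduler(redis_client, scheduling_mode="fixed_order")
    assert s.scheduling_mode == "fixed_order"
    # Later calls do not recreate the instance; mode arguments are routed
    # through set_priority_mode()/set_scheduling_mode(), so runtime config
    # changes take effect without discarding scheduler state.
    s = await get_cache_aware_scheduler(redis_client, scheduling_mode="cache_affinity")
    assert s.scheduling_mode == "cache_affinity"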

View File

@@ -1,5 +1,21 @@
"""
Model 映射缓存服务 - 减少模型查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作(CacheService):真正的 async,使用 aioredis
- 数据库查询(db.query):同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
"""
import time
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
class ModelCacheService:
"""Model 映射缓存服务"""
"""Model 映射缓存服务
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
所有公开方法均为 async,需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.MODEL
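
A condensed sketch of the read-through pattern this docstring describes (CacheService, GlobalModel, and CacheTTL follow the diff; the key format and `_row_to_dict` serializer are illustrative assumptions):

async def resolve_with_cache(db, name: str):
    cache_key = f"global_model:{name}"           # illustrative; real keys come from CacheKeys
    cached = await CacheService.get(cache_key)   # hit path is fully async (aioredis)
    if cached is not None:
        return cached
    # Miss path: synchronous SQLAlchemy query; per the docstring above,
    # FastAPI runs the sync portion in a worker thread.
    row = db.query(GlobalModel).filter(GlobalModel.name == name).first()
    if row is not None:
        await CacheService.set(cache_key, _row_to_dict(row), ttl_seconds=CacheTTL.MODEL)
    return row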

View File

@@ -1,254 +0,0 @@
"""
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
"""
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
class ProviderCacheService:
"""Provider 配置缓存服务"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.PROVIDER
@staticmethod
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
"""
获取 Provider(带缓存)
Args:
db: 数据库会话
provider_id: Provider ID
Returns:
Provider 对象或 None
"""
cache_key = CacheKeys.provider_by_id(provider_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Provider 缓存命中: {provider_id}")
return ProviderCacheService._dict_to_provider(cached_data)
# 2. 缓存未命中,查询数据库
provider = db.query(Provider).filter(Provider.id == provider_id).first()
# 3. 写入缓存
if provider:
provider_dict = ProviderCacheService._provider_to_dict(provider)
await CacheService.set(
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Provider 已缓存: {provider_id}")
return provider
@staticmethod
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
"""
获取 Endpoint(带缓存)
Args:
db: 数据库会话
endpoint_id: Endpoint ID
Returns:
ProviderEndpoint 对象或 None
"""
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
return ProviderCacheService._dict_to_endpoint(cached_data)
# 2. 缓存未命中,查询数据库
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
# 3. 写入缓存
if endpoint:
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
await CacheService.set(
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
return endpoint
@staticmethod
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
"""
获取 API Key(带缓存)
Args:
db: 数据库会话
api_key_id: API Key ID
Returns:
ProviderAPIKey 对象或 None
"""
cache_key = CacheKeys.api_key_by_id(api_key_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"API Key 缓存命中: {api_key_id}")
return ProviderCacheService._dict_to_api_key(cached_data)
# 2. 缓存未命中,查询数据库
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
# 3. 写入缓存
if api_key:
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
await CacheService.set(
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"API Key 已缓存: {api_key_id}")
return api_key
@staticmethod
async def invalidate_provider_cache(provider_id: str):
"""
清除 Provider 缓存
Args:
provider_id: Provider ID
"""
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
logger.debug(f"Provider 缓存已清除: {provider_id}")
@staticmethod
async def invalidate_endpoint_cache(endpoint_id: str):
"""
清除 Endpoint 缓存
Args:
endpoint_id: Endpoint ID
"""
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
@staticmethod
async def invalidate_api_key_cache(api_key_id: str):
"""
清除 API Key 缓存
Args:
api_key_id: API Key ID
"""
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
logger.debug(f"API Key 缓存已清除: {api_key_id}")
@staticmethod
def _provider_to_dict(provider: Provider) -> dict:
"""将 Provider 对象转换为字典(用于缓存)"""
return {
"id": provider.id,
"name": provider.name,
"api_format": provider.api_format,
"base_url": provider.base_url,
"is_active": provider.is_active,
"priority": provider.priority,
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
"config": provider.config,
"description": provider.description,
}
@staticmethod
def _dict_to_provider(provider_dict: dict) -> Provider:
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
from datetime import datetime
provider = Provider(
id=provider_dict["id"],
name=provider_dict["name"],
api_format=provider_dict["api_format"],
base_url=provider_dict.get("base_url"),
is_active=provider_dict["is_active"],
priority=provider_dict.get("priority", 0),
rpm_limit=provider_dict.get("rpm_limit"),
rpm_used=provider_dict.get("rpm_used", 0),
config=provider_dict.get("config"),
description=provider_dict.get("description"),
)
if provider_dict.get("rpm_reset_at"):
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
return provider
@staticmethod
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
"""将 Endpoint 对象转换为字典"""
return {
"id": endpoint.id,
"provider_id": endpoint.provider_id,
"name": endpoint.name,
"base_url": endpoint.base_url,
"is_active": endpoint.is_active,
"priority": endpoint.priority,
"weight": endpoint.weight,
"custom_path": endpoint.custom_path,
"config": endpoint.config,
}
@staticmethod
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
"""从字典重建 Endpoint 对象"""
endpoint = ProviderEndpoint(
id=endpoint_dict["id"],
provider_id=endpoint_dict["provider_id"],
name=endpoint_dict["name"],
base_url=endpoint_dict["base_url"],
is_active=endpoint_dict["is_active"],
priority=endpoint_dict.get("priority", 0),
weight=endpoint_dict.get("weight", 1.0),
custom_path=endpoint_dict.get("custom_path"),
config=endpoint_dict.get("config"),
)
return endpoint
@staticmethod
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
"""将 API Key 对象转换为字典"""
return {
"id": api_key.id,
"endpoint_id": api_key.endpoint_id,
"key_value": api_key.key_value,
"is_active": api_key.is_active,
"max_rpm": api_key.max_rpm,
"current_rpm": api_key.current_rpm,
"health_score": api_key.health_score,
"circuit_breaker_state": api_key.circuit_breaker_state,
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
}
@staticmethod
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
"""从字典重建 API Key 对象"""
api_key = ProviderAPIKey(
id=api_key_dict["id"],
endpoint_id=api_key_dict["endpoint_id"],
key_value=api_key_dict["key_value"],
is_active=api_key_dict["is_active"],
max_rpm=api_key_dict.get("max_rpm"),
current_rpm=api_key_dict.get("current_rpm", 0),
health_score=api_key_dict.get("health_score", 1.0),
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
)
return api_key

View File

@@ -1,5 +1,22 @@
"""
用户缓存服务 - 减少数据库查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作(CacheService):真正的 async,使用 aioredis
- 数据库查询(db.query):同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
user = await UserCacheService.get_user_by_id(db, user_id)
await UserCacheService.invalidate_user_cache(user_id, email)
"""
from typing import Optional
@@ -12,9 +29,12 @@ from src.core.logger import logger
from src.models.database import User
class UserCacheService:
"""用户缓存服务"""
"""用户缓存服务
提供 User 的缓存查询功能,减少数据库访问。
所有公开方法均为 async,需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.USER

View File

@@ -13,6 +13,7 @@ from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
from src.models.database import Model, Provider
from src.api.base.models_service import invalidate_models_list_cache
from src.services.cache.invalidation import get_cache_invalidation_service
from src.services.cache.model_cache import ModelCacheService
@@ -75,6 +76,10 @@ class ModelService:
)
logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
# 清除 /v1/models 列表缓存
asyncio.create_task(invalidate_models_list_cache())
return model
except IntegrityError as e:
@@ -197,6 +202,9 @@ class ModelService:
cache_service = get_cache_invalidation_service()
cache_service.on_model_changed(model.provider_id, model.global_model_id)
# 清除 /v1/models 列表缓存
asyncio.create_task(invalidate_models_list_cache())
logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
return model
except IntegrityError as e:
@@ -261,6 +269,9 @@ class ModelService:
cache_service = get_cache_invalidation_service()
cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])
# 清除 /v1/models 列表缓存
asyncio.create_task(invalidate_models_list_cache())
logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
except Exception as e:
@@ -295,6 +306,9 @@ class ModelService:
cache_service = get_cache_invalidation_service()
cache_service.on_model_changed(model.provider_id, model.global_model_id)
# 清除 /v1/models 列表缓存
asyncio.create_task(invalidate_models_list_cache())
status = "可用" if is_available else "不可用"
logger.info(f"更新模型可用状态: id={model_id}, status={status}")
return model
@@ -358,6 +372,9 @@ class ModelService:
for model in created_models:
db.refresh(model)
logger.info(f"批量创建 {len(created_models)} 个模型成功")
# 清除 /v1/models 列表缓存
asyncio.create_task(invalidate_models_list_cache())
except IntegrityError as e:
db.rollback()
logger.error(f"批量创建模型失败: {str(e)}")

View File

@@ -15,6 +15,7 @@ from src.core.enums import APIFormat
from src.core.exceptions import (
ConcurrencyLimitError,
ProviderAuthException,
ProviderCompatibilityException,
ProviderException,
ProviderNotAvailableException,
ProviderRateLimitException,
@@ -69,24 +70,31 @@ class ErrorClassifier:
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
# 这里主要用于匹配非标准格式或第三方代理的错误消息
#
# 重要:不要在此列表中包含 Provider Key 配置问题(如 invalid_api_key)
# 这类错误应该触发故障转移,而不是直接返回给用户
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
"could not process image", # 图片处理失败
"image too large", # 图片过大
"invalid image", # 无效图片
"unsupported image", # 不支持的图片格式
"content_policy_violation", # 内容违规
"invalid_api_key", # 无效的 API Key不同于认证失败
"context_length_exceeded", # 上下文长度超限
"content_length_limit", # 请求内容长度超限 (Claude API)
"max_tokens", # token 数超限
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
# 注意:移除了 "max_tokens",因为 max_tokens 相关错误可能是 Provider 兼容性问题
# 如 "Unsupported parameter: 'max_tokens' is not supported with this model"
# 这类错误应由 COMPATIBILITY_ERROR_PATTERNS 处理
"invalid_prompt", # 无效的提示词
"content too long", # 内容过长
"input is too long", # 输入过长 (AWS)
"message is too long", # 消息过长
"prompt is too long", # Prompt 超长(第三方代理常见格式)
"image exceeds", # 图片超出限制
"pdf too large", # PDF 过大
"file too large", # 文件过大
"tool_use_id", # tool_result 引用了不存在的 tool_use兼容非标准代理
"validationexception", # AWS 验证异常
)
def __init__(
@@ -110,18 +118,137 @@ class ErrorClassifier:
# 表示客户端错误的 error type不区分大小写
# 这些 type 表明是请求本身的问题,不应重试
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
"invalid_request_error", # Claude/OpenAI 标准客户端错误类型
"invalid_argument", # Gemini 参数错误
"failed_precondition", # Gemini 前置条件错误
# Claude/OpenAI 标准
"invalid_request_error",
# Gemini
"invalid_argument",
"failed_precondition",
# AWS
"validationexception",
# 通用
"validation_error",
"bad_request",
)
# 表示客户端错误的 reason/code 字段值
CLIENT_ERROR_REASONS: Tuple[str, ...] = (
"CONTENT_LENGTH_EXCEEDS_THRESHOLD",
"CONTEXT_LENGTH_EXCEEDED",
"MAX_TOKENS_EXCEEDED",
"INVALID_CONTENT",
"CONTENT_POLICY_VIOLATION",
)
# Provider 兼容性错误模式 - 这类错误应该触发故障转移
# 因为换一个 Provider 可能就能成功
COMPATIBILITY_ERROR_PATTERNS: Tuple[str, ...] = (
"unsupported parameter", # 不支持的参数
"unsupported model", # 不支持的模型
"unsupported feature", # 不支持的功能
"not supported with this model", # 此模型不支持
"model does not support", # 模型不支持
"parameter is not supported", # 参数不支持
"feature is not supported", # 功能不支持
"not available for this model", # 此模型不可用
)
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
"""
解析错误响应为结构化数据
支持多种格式:
- {"error": {"type": "...", "message": "..."}} (Claude/OpenAI)
- {"error": {"message": "...", "__type": "..."}} (AWS)
- {"errorMessage": "..."} (Lambda)
- {"error": "..."}
- {"message": "...", "reason": "..."}
Returns:
结构化的错误信息: {
"type": str, # 错误类型
"message": str, # 错误消息
"reason": str, # 错误原因/代码
"raw": str, # 原始文本
}
"""
result = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
if not error_text:
return result
try:
data = json.loads(error_text)
# 格式 1: {"error": {"type": "...", "message": "..."}}
if isinstance(data.get("error"), dict):
error_obj = data["error"]
result["type"] = str(error_obj.get("type", ""))
result["message"] = str(error_obj.get("message", ""))
# AWS 格式: {"error": {"__type": "...", "message": "...", "reason": "..."}}
# __type 直接在 error 对象中,而不是嵌套在 message 里
if "__type" in error_obj:
result["type"] = result["type"] or str(error_obj.get("__type", ""))
if "reason" in error_obj:
result["reason"] = str(error_obj.get("reason", ""))
if "code" in error_obj:
result["reason"] = result["reason"] or str(error_obj.get("code", ""))
# 嵌套 JSON 格式: message 字段本身是 JSON 字符串
# 支持多种嵌套格式:
# - AWS: {"__type": "...", "message": "...", "reason": "..."}
# - 第三方代理: {"error": {"type": "...", "message": "..."}}
if result["message"].startswith("{"):
try:
nested = json.loads(result["message"])
if isinstance(nested, dict):
# AWS 格式
if "__type" in nested:
result["type"] = result["type"] or str(nested.get("__type", ""))
result["message"] = str(nested.get("message", result["message"]))
result["reason"] = str(nested.get("reason", ""))
# 第三方代理格式: {"error": {"message": "..."}}
elif isinstance(nested.get("error"), dict):
inner_error = nested["error"]
inner_msg = str(inner_error.get("message", ""))
if inner_msg:
result["message"] = inner_msg
# 简单格式: {"message": "..."}
elif "message" in nested:
result["message"] = str(nested["message"])
except json.JSONDecodeError:
pass
# 格式 2: {"error": "..."}
elif isinstance(data.get("error"), str):
result["message"] = str(data["error"])
# 格式 3: {"errorMessage": "..."} (Lambda)
elif "errorMessage" in data:
result["message"] = str(data["errorMessage"])
# 格式 4: {"message": "...", "reason": "..."}
elif "message" in data:
result["message"] = str(data["message"])
result["reason"] = str(data.get("reason", ""))
# 提取顶层的 reason/code
if not result["reason"]:
result["reason"] = str(data.get("reason", data.get("code", "")))
except (json.JSONDecodeError, TypeError, KeyError):
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
return result
def _is_client_error(self, error_text: Optional[str]) -> bool:
"""
检测错误响应是否为客户端错误(不应重试)
判断逻辑:
判断逻辑(按优先级):
1. 检查 error.type 是否为已知的客户端错误类型
2. 检查错误文本是否包含已知的客户端错误模式
2. 检查 reason/code 是否为已知的客户端错误原因
3. 回退到关键词匹配
Args:
error_text: 错误响应文本
@@ -132,67 +259,72 @@ class ErrorClassifier:
if not error_text:
return False
# 尝试解析 JSON 并检查 error type
try:
data = json.loads(error_text)
if isinstance(data.get("error"), dict):
error_type = data["error"].get("type", "")
if error_type and any(
t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
):
return True
except (json.JSONDecodeError, TypeError, KeyError):
pass
parsed = self._parse_error_response(error_text)
# 回退到关键词匹配
error_lower = error_text.lower()
return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
# 1. 检查 error type
if parsed["type"]:
error_type_lower = parsed["type"].lower()
if any(t.lower() in error_type_lower for t in self.CLIENT_ERROR_TYPES):
return True
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
# 2. 检查 reason/code
if parsed["reason"]:
reason_upper = parsed["reason"].upper()
if any(r in reason_upper for r in self.CLIENT_ERROR_REASONS):
return True
# 3. 回退到关键词匹配(合并 message 和 raw)
search_text = f"{parsed['message']} {parsed['raw']}".lower()
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
def _is_compatibility_error(self, error_text: Optional[str]) -> bool:
"""
从错误响应中提取错误消息
检测错误响应是否为 Provider 兼容性错误(应触发故障转移)
支持格式:
- {"error": {"message": "..."}} (OpenAI/Claude)
- {"error": {"type": "...", "message": "..."}}
- {"error": "..."}
- {"message": "..."}
这类错误是因为 Provider 不支持某些参数或功能导致的,
换一个 Provider 可能就能成功。
Args:
error_text: 错误响应文本
Returns:
提取的错误消息,如果无法解析则返回原始文本
是否为兼容性错误
"""
if not error_text:
return False
search_text = error_text.lower()
return any(pattern.lower() in search_text for pattern in self.COMPATIBILITY_ERROR_PATTERNS)
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
"""
从错误响应中提取错误消息
Args:
error_text: 错误响应文本
Returns:
提取的错误消息
"""
if not error_text:
return None
try:
data = json.loads(error_text)
parsed = self._parse_error_response(error_text)
# {"error": {"message": "..."}} 或 {"error": {"type": "...", "message": "..."}}
if isinstance(data.get("error"), dict):
error_obj = data["error"]
message = error_obj.get("message", "")
error_type = error_obj.get("type", "")
if message:
if error_type:
return f"{error_type}: {message}"
return str(message)
# 构建可读的错误消息
parts = []
if parsed["type"]:
parts.append(parsed["type"])
if parsed["reason"]:
parts.append(f"[{parsed['reason']}]")
if parsed["message"]:
parts.append(parsed["message"])
# {"error": "..."}
if isinstance(data.get("error"), str):
return str(data["error"])
# {"message": "..."}
if isinstance(data.get("message"), str):
return str(data["message"])
except (json.JSONDecodeError, TypeError, KeyError):
pass
if parts:
return ": ".join(parts) if len(parts) > 1 else parts[0]
# 无法解析,返回原始文本(截断)
return error_text[:500] if len(error_text) > 500 else error_text
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
def classify(
self,
@@ -328,6 +460,16 @@ class ErrorClassifier:
),
)
# 400 错误:先检查是否为 Provider 兼容性错误(应触发故障转移)
if status == 400 and self._is_compatibility_error(error_response_text):
logger.info(f"检测到 Provider 兼容性错误,将触发故障转移: {extracted_message}")
return ProviderCompatibilityException(
message=extracted_message or "Provider 不支持此请求",
provider_name=provider_name,
status_code=400,
upstream_error=error_response_text,
)
# 400 错误:检查是否为客户端请求错误(不应重试)
if status == 400 and self._is_client_error(error_response_text):
logger.info(f"检测到客户端请求错误,不进行重试: {extracted_message}")

View File

@@ -102,9 +102,15 @@ class FallbackOrchestrator:
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
self.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
self.cache_scheduler = await get_cache_aware_scheduler(
self.redis,
priority_mode=priority_mode,
scheduling_mode=scheduling_mode,
)
else:
# 确保运行时配置变更能生效
@@ -113,7 +119,13 @@ class FallbackOrchestrator:
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
self.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
self.cache_scheduler.set_priority_mode(priority_mode)
self.cache_scheduler.set_scheduling_mode(scheduling_mode)
# 确保 cache_scheduler 内部组件也已初始化
await self.cache_scheduler._ensure_initialized()
@@ -415,6 +427,9 @@ class FallbackOrchestrator:
)
# str(cause) 可能为空(如 httpx 超时异常),使用 repr() 作为备用
error_msg = str(cause) or repr(cause)
# 如果是 ProviderNotAvailableException,附加上游响应
if hasattr(cause, "upstream_response") and cause.upstream_response:
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,
@@ -427,6 +442,9 @@ class FallbackOrchestrator:
# 未知错误:记录失败并抛出
error_msg = str(cause) or repr(cause)
# 如果是 ProviderNotAvailableException,附加上游响应
if hasattr(cause, "upstream_response") and cause.upstream_response:
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,

View File

@@ -1,61 +0,0 @@
"""响应标准化服务,用于 STANDARD 模式下的响应格式验证和补全"""
from typing import Any, Dict, Optional
from src.core.logger import logger
from src.models.claude import ClaudeResponse
class ResponseNormalizer:
"""响应标准化器 - 用于标准模式下验证和补全响应字段"""
@staticmethod
def normalize_claude_response(
response_data: Dict[str, Any], request_id: Optional[str] = None
) -> Dict[str, Any]:
"""
标准化 Claude API 响应
Args:
response_data: 原始响应数据
request_id: 请求ID用于日志
Returns:
标准化后的响应数据(失败时返回原始数据)
"""
if "error" in response_data:
logger.debug(f"[ResponseNormalizer] 检测到错误响应,跳过标准化 | ID:{request_id}")
return response_data
try:
validated = ClaudeResponse.model_validate(response_data)
normalized = validated.model_dump(mode="json", exclude_none=False)
logger.debug(f"[ResponseNormalizer] 响应标准化成功 | ID:{request_id}")
return normalized
except Exception as e:
logger.debug(f"[ResponseNormalizer] 响应验证失败,透传原始数据 | ID:{request_id}")
return response_data
@staticmethod
def should_normalize(response_data: Dict[str, Any]) -> bool:
"""
判断是否需要进行标准化
Args:
response_data: 响应数据
Returns:
是否需要标准化
"""
# 错误响应不需要标准化
if "error" in response_data:
return False
# 已经包含新字段的响应不需要再次标准化
if "context_management" in response_data and "container" in response_data:
return False
return True

View File

@@ -5,6 +5,10 @@
- 使用滑动窗口采样,容忍并发波动
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
AIMD 参数说明:
- 扩容:加性增加 (+INCREASE_STEP)
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
"""
from datetime import datetime, timezone
@@ -34,7 +38,7 @@ class AdaptiveConcurrencyManager:
核心算法:基于滑动窗口利用率的 AIMD
- 滑动窗口记录最近 N 次请求的利用率
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
- 遇到 429 错误时乘性减少 (*0.7)
- 遇到 429 错误时乘性减少 (*0.85)
- 长时间无 429 且有流量时触发探测性扩容
扩容条件(满足任一即可):
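
The AIMD rule in isolation; only the 0.85 multiplier comes from this diff, the step size and bounds are illustrative:

INCREASE_STEP = 1           # additive increase (illustrative value)
DECREASE_MULTIPLIER = 0.85  # multiplicative decrease, per the hunk above
MIN_LIMIT, MAX_LIMIT = 1, 100  # illustrative clamps

def next_limit(current: int, high_utilization: bool, got_429: bool) -> int:
    if got_429:
        # Back off quickly when the provider signals rate pressure.
        return max(MIN_LIMIT, int(current * DECREASE_MULTIPLIER))
    if high_utilization:
        # Probe for headroom slowly while the window shows sustained demand.
        return min(MAX_LIMIT, current + INCREASE_STEP)
    return current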

View File

@@ -11,7 +11,6 @@
import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple
@@ -40,6 +39,7 @@ class ConcurrencyManager:
self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {}
self._owns_redis: bool = False
self._memory_initialized = True
async def initialize(self) -> None:
@@ -47,41 +47,29 @@ class ConcurrencyManager:
if self._redis is not None:
return
# 优先使用 REDIS_URL,如果没有则根据密码构建 URL
redis_url = os.getenv("REDIS_URL")
if not redis_url:
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
redis_password = os.getenv("REDIS_PASSWORD")
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
# 测试连接
await self._redis.ping()
# 脱敏显示(隐藏密码)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
logger.info(f"[OK] Redis 连接成功: {safe_url}")
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
from src.clients.redis_client import get_redis_client
self._redis = await get_redis_client(require_redis=False)
self._owns_redis = False
if self._redis:
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
else:
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
except Exception as e:
logger.error(f"[ERROR] Redis 连接失败: {e}")
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
self._redis = None
self._owns_redis = False
async def close(self) -> None:
"""关闭 Redis 连接"""
if self._redis:
if self._redis and self._owns_redis:
await self._redis.close()
self._redis = None
logger.info("Redis 连接已关闭")
logger.info("ConcurrencyManager Redis 连接已关闭")
self._redis = None
self._owns_redis = False
def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key"""

View File

@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
@@ -72,11 +72,7 @@ class RPMLimiter:
# 获取当前分钟窗口
now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0)
window_end = (
window_start.replace(minute=window_start.minute + 1)
if window_start.minute < 59
else window_start.replace(hour=window_start.hour + 1, minute=0)
)
window_end = window_start + timedelta(minutes=1)
# 查找或创建追踪记录
tracking = (
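
Why the timedelta form is the right fix: the replaced branch special-cased minute 59 but still raised at 23:59, where hour+1 produces an invalid hour of 24:

from datetime import datetime, timedelta, timezone

now = datetime(2025, 12, 19, 23, 59, 30, tzinfo=timezone.utc)
window_start = now.replace(second=0, microsecond=0)

# Old branch at 23:59: window_start.replace(hour=24, minute=0) -> ValueError.
# timedelta rolls minutes, hours, and days over uniformly:
window_end = window_start + timedelta(minutes=1)
assert window_end == datetime(2025, 12, 20, 0, 0, tzinfo=timezone.utc)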

View File

@@ -289,11 +289,17 @@ class RequestResult:
status_code = 500
error_type = "internal_error"
# 构建错误消息,包含上游响应信息
error_message = str(exception)
if isinstance(exception, ProviderNotAvailableException):
if exception.upstream_response:
error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
return cls(
status=RequestStatus.FAILED,
metadata=metadata,
status_code=status_code,
error_message=str(exception),
error_message=error_message,
error_type=error_type,
response_time_ms=response_time_ms,
is_stream=is_stream,

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import get_db
from src.models.database import AuditEventType, AuditLog
from src.utils.transaction_manager import transactional
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
class AuditService:
"""审计服务"""
"""审计服务
事务策略:本服务不负责事务提交,由中间件统一管理。
所有方法只做 db.add/flush提交由请求结束时的中间件处理。
"""
@staticmethod
@transactional(commit=False) # 不自动提交,让调用方决定
def log_event(
db: Session,
event_type: AuditEventType,
@@ -54,47 +56,44 @@ class AuditService:
Returns:
审计日志记录
Note:
不在此方法内提交事务,由调用方或中间件统一管理。
"""
try:
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
db.add(audit_log)
db.commit() # 立即提交事务,释放数据库锁
db.refresh(audit_log)
db.add(audit_log)
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
db.flush()
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
db.rollback()
return None
return audit_log
@staticmethod
def log_login_attempt(

View File

@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import create_session
from src.models.database import Usage
from src.models.database import AuditLog, Usage
from src.services.system.config import SystemConfigService
from src.services.system.scheduler import get_scheduler
from src.services.system.stats_aggregator import StatsAggregatorService
@@ -35,6 +35,7 @@ class CleanupScheduler:
def __init__(self):
self.running = False
self._interval_tasks = []
self._stats_aggregation_lock = asyncio.Lock()
async def start(self):
"""启动调度器"""
@@ -56,6 +57,14 @@ class CleanupScheduler:
job_id="stats_aggregation",
name="统计数据聚合",
)
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
scheduler.add_interval_job(
self._scheduled_stats_aggregation,
minutes=30,
job_id="stats_aggregation_backfill",
name="统计数据聚合补偿",
backfill=True,
)
# 清理任务 - 凌晨 3 点执行
scheduler.add_cron_job(
@@ -82,6 +91,15 @@ class CleanupScheduler:
name="Pending状态清理",
)
# 审计日志清理 - 凌晨 4 点执行
scheduler.add_cron_job(
self._scheduled_audit_cleanup,
hour=4,
minute=0,
job_id="audit_cleanup",
name="审计日志清理",
)
# 启动时执行一次初始化任务
asyncio.create_task(self._run_startup_tasks())
@@ -115,9 +133,9 @@ class CleanupScheduler:
# ========== 任务函数APScheduler 直接调用异步函数) ==========
async def _scheduled_stats_aggregation(self):
async def _scheduled_stats_aggregation(self, backfill: bool = False):
"""统计聚合任务(定时调用)"""
await self._perform_stats_aggregation()
await self._perform_stats_aggregation(backfill=backfill)
async def _scheduled_cleanup(self):
"""清理任务(定时调用)"""
@@ -136,6 +154,10 @@ class CleanupScheduler:
"""Pending 清理任务(定时调用)"""
await self._perform_pending_cleanup()
async def _scheduled_audit_cleanup(self):
"""审计日志清理任务(定时调用)"""
await self._perform_audit_cleanup()
# ========== 实际任务实现 ==========
async def _perform_stats_aggregation(self, backfill: bool = False):
@@ -144,136 +166,157 @@ class CleanupScheduler:
Args:
backfill: 是否回填历史数据(启动时检查缺失的日期)
"""
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
if self._stats_aggregation_lock.locked():
logger.info("统计聚合任务正在运行,跳过本次触发")
return
logger.info("开始执行统计数据聚合...")
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
async with self._stats_aggregation_lock:
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = (
db.query(StatsDaily)
.order_by(StatsDaily.date.desc())
.first()
)
logger.info("开始执行统计数据聚合...")
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = (today_local.date() - timedelta(days=1))
missing_start_date = latest_business_date + timedelta(days=1)
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if missing_start_date <= yesterday_business_date:
missing_days = (yesterday_business_date - missing_start_date).days + 1
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
return
current_date = missing_start_date
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
while current_date <= yesterday_business_date:
try:
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = today_local.date() - timedelta(days=1)
missing_start_date = latest_business_date + timedelta(days=1)
if missing_start_date <= yesterday_business_date:
missing_days = (
yesterday_business_date - missing_start_date
).days + 1
# 限制最大回填天数,防止停机很久后一次性回填太多
max_backfill_days: int = SystemConfigService.get_config(
db, "max_stats_backfill_days", 30
) or 30
if missing_days > max_backfill_days:
logger.warning(
f"缺失 {missing_days} 天数据超过最大回填限制 "
f"{max_backfill_days} 天,只回填最近 {max_backfill_days}"
)
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
# 聚合用户数据
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
missing_start_date = yesterday_business_date - timedelta(
days=max_backfill_days - 1
)
missing_days = max_backfill_days
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
)
current_date = missing_start_date
users = (
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
)
while current_date <= yesterday_business_date:
try:
db.rollback()
except Exception:
pass
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
)
StatsAggregatorService.aggregate_daily_stats(
db, current_date_local
)
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
try:
db.rollback()
except Exception:
pass
current_date += timedelta(days=1)
current_date += timedelta(days=1)
# 更新全局汇总
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
# 定时任务:聚合昨天的数据
# 注意aggregate_daily_stats 期望业务时区的日期,不是 UTC
yesterday_local = today_local - timedelta(days=1)
# 定时任务:聚合昨天的数据
yesterday_local = today_local - timedelta(days=1)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
# 聚合所有用户的昨日数据
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
# 回滚当前用户的失败操作,继续处理其他用户
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
db.rollback()
except Exception:
pass
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, yesterday_local
)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
try:
db.rollback()
except Exception:
pass
# 更新全局汇总
StatsAggregatorService.update_summary(db)
StatsAggregatorService.update_summary(db)
logger.info("统计数据聚合完成")
logger.info("统计数据聚合完成")
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
db.rollback()
finally:
db.close()
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
try:
db.rollback()
except Exception:
pass
finally:
db.close()
async def _perform_pending_cleanup(self):
"""执行 pending 状态清理"""
@@ -300,6 +343,70 @@ class CleanupScheduler:
finally:
db.close()
async def _perform_audit_cleanup(self):
"""执行审计日志清理任务"""
db = create_session()
try:
# 检查是否启用自动清理
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
logger.info("自动清理已禁用,跳过审计日志清理")
return
# 获取审计日志保留天数(默认 30 天,最少 7 天)
audit_retention_days = max(
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
7, # 最少保留 7 天,防止误配置删除所有审计日志
)
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
total_deleted = 0
while True:
# 先查询要删除的记录 ID(分批)
records_to_delete = (
db.query(AuditLog.id)
.filter(AuditLog.created_at < cutoff_time)
.limit(batch_size)
.all()
)
if not records_to_delete:
break
record_ids = [r.id for r in records_to_delete]
# 执行删除
result = db.execute(
delete(AuditLog)
.where(AuditLog.id.in_(record_ids))
.execution_options(synchronize_session=False)
)
rows_deleted = result.rowcount
db.commit()
total_deleted += rows_deleted
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted}")
await asyncio.sleep(0.1)
if total_deleted > 0:
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
else:
logger.info("无需清理的审计日志")
except Exception as e:
logger.exception(f"审计日志清理失败: {e}")
try:
db.rollback()
except Exception:
pass
finally:
db.close()
async def _perform_cleanup(self):
"""执行清理任务"""
db = create_session()

View File

@@ -12,7 +12,6 @@ from src.core.logger import logger
from src.models.database import Provider, SystemConfig
class LogLevel(str, Enum):
"""日志记录级别"""
@@ -71,6 +70,10 @@ class SystemConfigService:
"value": "provider",
"description": "优先级策略provider(提供商优先模式) 或 global_key(全局Key优先模式)",
},
"scheduling_mode": {
"value": "cache_affinity",
"description": "调度模式fixed_order(固定顺序模式,严格按优先级顺序) 或 cache_affinity(缓存亲和模式优先使用已缓存的Provider)",
},
"auto_delete_expired_keys": {
"value": False,
"description": "是否自动删除过期的API KeyTrue=物理删除False=仅禁用),仅管理员可配置",
@@ -90,6 +93,35 @@ class SystemConfigService:
return default
@classmethod
def get_configs(cls, db: Session, keys: List[str]) -> Dict[str, Any]:
"""
批量获取系统配置值
Args:
db: 数据库会话
keys: 配置键列表
Returns:
配置键值字典
"""
result = {}
# 一次查询获取所有配置
configs = db.query(SystemConfig).filter(SystemConfig.key.in_(keys)).all()
config_map = {c.key: c.value for c in configs}
# 填充结果,不存在的使用默认值
for key in keys:
if key in config_map:
result[key] = config_map[key]
elif key in cls.DEFAULT_CONFIGS:
result[key] = cls.DEFAULT_CONFIGS[key]["value"]
else:
result[key] = None
return result
@staticmethod
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
"""设置系统配置值"""
@@ -107,6 +139,7 @@ class SystemConfigService:
db.commit()
db.refresh(config)
return config
@staticmethod
@@ -149,8 +182,8 @@ class SystemConfigService:
for config in configs
]
@staticmethod
def delete_config(db: Session, key: str) -> bool:
@classmethod
def delete_config(cls, db: Session, key: str) -> bool:
"""删除系统配置"""
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config:

View File

@@ -56,65 +56,44 @@ class StatsAggregatorService:
"""统计数据聚合服务"""
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
# 将业务日期转换为 UTC 时间范围
def compute_daily_stats(db: Session, date: datetime) -> dict:
"""计算指定业务日期的统计数据(不写入数据库)"""
day_start, day_end = _get_business_day_range(date)
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 检查是否已存在该日期的记录
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# 基础请求统计
base_query = db.query(Usage).filter(
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
)
total_requests = base_query.count()
# 如果没有请求,直接返回空记录
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
stats.actual_total_cost = 0.0
stats.input_cost = 0.0
stats.output_cost = 0.0
stats.cache_creation_cost = 0.0
stats.cache_read_cost = 0.0
stats.avg_response_time_ms = 0.0
stats.fallback_count = 0
return {
"day_start": day_start,
"total_requests": 0,
"success_requests": 0,
"error_requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
"total_cost": 0.0,
"actual_total_cost": 0.0,
"input_cost": 0.0,
"output_cost": 0.0,
"cache_creation_cost": 0.0,
"cache_read_cost": 0.0,
"avg_response_time_ms": 0.0,
"fallback_count": 0,
"unique_models": 0,
"unique_providers": 0,
}
if not existing:
db.add(stats)
db.commit()
return stats
# 错误请求数
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
# Token 和成本聚合
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
@@ -157,7 +136,6 @@ class StatsAggregatorService:
or 0
)
# 使用维度统计
unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
@@ -171,31 +149,74 @@ class StatsAggregatorService:
or 0
)
return {
"day_start": day_start,
"total_requests": total_requests,
"success_requests": total_requests - error_requests,
"error_requests": error_requests,
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
"fallback_count": fallback_count,
"unique_models": unique_models,
"unique_providers": unique_providers,
}
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
computed = StatsAggregatorService.compute_daily_stats(db, date)
day_start = computed["day_start"]
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 检查是否已存在该日期的记录
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# 更新统计记录
stats.total_requests = total_requests
stats.success_requests = total_requests - error_requests
stats.error_requests = error_requests
stats.input_tokens = int(aggregated.input_tokens or 0)
stats.output_tokens = int(aggregated.output_tokens or 0)
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
stats.total_cost = float(aggregated.total_cost or 0)
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
stats.input_cost = float(aggregated.input_cost or 0)
stats.output_cost = float(aggregated.output_cost or 0)
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
stats.fallback_count = fallback_count
stats.unique_models = unique_models
stats.unique_providers = unique_providers
stats.total_requests = computed["total_requests"]
stats.success_requests = computed["success_requests"]
stats.error_requests = computed["error_requests"]
stats.input_tokens = computed["input_tokens"]
stats.output_tokens = computed["output_tokens"]
stats.cache_creation_tokens = computed["cache_creation_tokens"]
stats.cache_read_tokens = computed["cache_read_tokens"]
stats.total_cost = computed["total_cost"]
stats.actual_total_cost = computed["actual_total_cost"]
stats.input_cost = computed["input_cost"]
stats.output_cost = computed["output_cost"]
stats.cache_creation_cost = computed["cache_creation_cost"]
stats.cache_read_cost = computed["cache_read_cost"]
stats.avg_response_time_ms = computed["avg_response_time_ms"]
stats.fallback_count = computed["fallback_count"]
stats.unique_models = computed["unique_models"]
stats.unique_providers = computed["unique_providers"]
if not existing:
db.add(stats)
db.commit()
# 日志使用业务日期(输入参数),而不是 UTC 日期
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
return stats
@staticmethod

File diff suppressed because it is too large

View File

@@ -457,7 +457,7 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
# 更新状态为 streaming
# 更新状态为 streaming,同时更新 provider
if self.request_id:
try:
from src.services.usage.service import UsageService
@@ -465,6 +465,7 @@ class StreamUsageTracker:
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")

View File

@@ -210,7 +210,15 @@ class ApiKeyService:
@staticmethod
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
"""检查速率限制"""
"""检查速率限制
Returns:
(is_allowed, remaining): 是否允许请求,剩余可用次数
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
"""
# 如果 rate_limit 为 None,表示不限制
if api_key.rate_limit is None:
return True, -1 # -1 表示无限制
# 计算时间窗口
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
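
A hypothetical caller honoring the documented contract (the HTTPException and logger usage here are illustrative, not repo code):

from fastapi import HTTPException

def enforce_rate_limit(db, api_key) -> None:
    allowed, remaining = ApiKeyService.check_rate_limit(db, api_key)
    if not allowed:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    if remaining == -1:
        return  # rate_limit is None: unlimited, nothing to report
    logger.debug(f"rate limit remaining: {remaining}")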

View File

@@ -71,8 +71,8 @@ class PreferenceService:
raise NotFoundException("Provider not found or inactive")
preferences.default_provider_id = default_provider_id
if theme is not None:
if theme not in ["light", "dark", "auto"]:
raise ValueError("Invalid theme. Must be 'light', 'dark', or 'auto'")
if theme not in ["light", "dark", "auto", "system"]:
raise ValueError("Invalid theme. Must be 'light', 'dark', 'auto', or 'system'")
preferences.theme = theme
if language is not None:
preferences.language = language

View File

@@ -3,6 +3,7 @@
提供统一的用户认证和授权功能
"""
import hashlib
from typing import Optional
from fastapi import Depends, Header, HTTPException, status
@@ -41,13 +42,20 @@ async def get_current_user(
try:
# 验证Token格式和签名
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
except HTTPException as token_error:
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.error(
"Token验证失败: {}: {}, token_fp={}",
token_error.status_code,
token_error.detail,
token_fp,
)
raise # 重新抛出原始异常,保持状态码
except Exception as token_error:
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
raise ForbiddenException("无效的Token")
user_id = payload.get("user_id")
@@ -63,7 +71,8 @@ async def get_current_user(
raise ForbiddenException("无效的认证凭据")
# 仅在DEBUG模式下记录详细信息
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
# 确保user_id是字符串格式(UUID)
if not isinstance(user_id, str):
@@ -144,7 +153,7 @@ async def get_current_user_from_header(
token = authorization.replace("Bearer ", "")
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:

View File

@@ -7,29 +7,47 @@ from typing import Optional
from fastapi import Request
from src.config import config
def get_client_ip(request: Request) -> str:
"""
获取客户端真实IP地址
按优先级检查:
1. X-Forwarded-For 头(支持代理链)
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
2. X-Real-IP 头(Nginx 代理)
3. 直接客户端IP
安全说明:
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
Args:
request: FastAPI Request 对象
Returns:
str: 客户端IP地址,如果无法获取则返回 "unknown"
"""
trusted_proxy_count = config.trusted_proxy_count
# 如果不信任任何代理,直接返回连接 IP
if trusted_proxy_count == 0:
if request.client and request.client.host:
return request.client.host
return "unknown"
# 优先检查 X-Forwarded-For 头(可能包含代理链)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
client_ip = forwarded_for.split(",")[0].strip()
if client_ip:
return client_ip
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# 检查 X-Real-IP 头(通常由 Nginx 设置)
real_ip = request.headers.get("X-Real-IP")
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
}
def extract_ip_from_headers(headers: dict) -> str:
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
"""
从HTTP头字典中提取IP地址(用于中间件等场景)
Args:
headers: HTTP头字典
trusted_proxy_count: 可信代理层数(None 时使用配置值)
Returns:
str: 客户端IP地址
"""
if trusted_proxy_count is None:
trusted_proxy_count = config.trusted_proxy_count
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
if trusted_proxy_count == 0:
return "unknown"
# 检查 X-Forwarded-For
forwarded_for = headers.get("x-forwarded-for", "")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# 检查 X-Real-IP
real_ip = headers.get("x-real-ip", "")

Some files were not shown because too many files have changed in this diff