49 Commits

Author SHA1 Message Date
fawney19
6aa1876955 feat: add Dockerfile.app.local for China mirror support 2025-12-19 16:20:02 +08:00
fawney19
7f07122aea refactor: separate frontend build from base image for faster incremental builds 2025-12-19 16:02:38 +08:00
fawney19
c2ddc6bd3c refactor: optimize Docker build with multi-stage and slim runtime base image 2025-12-19 15:51:21 +08:00
fawney19
af476ff21e feat: enhance error logging and upstream response tracking for provider failures 2025-12-19 15:29:48 +08:00
fawney19
3bbc1c6b66 feat: add provider compatibility error detection for intelligent failover
- Introduce ProviderCompatibilityException for unsupported parameter/feature errors
- Add COMPATIBILITY_ERROR_PATTERNS to detect provider-specific limitations
- Implement _is_compatibility_error() method in ErrorClassifier
- Prioritize compatibility error checking before client error validation
- Remove 'max_tokens' from CLIENT_ERROR_PATTERNS as it can indicate compatibility issues
- Enable automatic failover when provider doesn't support requested features
- Improve error classification accuracy with pattern matching for common compatibility issues
2025-12-19 13:28:26 +08:00
fawney19
c69a0a8506 refactor: remove stream smoothing config from system settings and improve base image caching
- Remove stream_smoothing configuration from SystemConfigService (moved to handler default)
- Remove stream smoothing UI controls from admin settings page
- Add AdminClearSingleAffinityAdapter for targeted cache invalidation
- Add clearSingleAffinity API endpoint to clear specific affinity cache entries
- Include global_model_id in affinity list response for UI deletion support
- Improve CI/CD workflow with hash-based base image change detection
- Add hash label to base image for reliable cache invalidation detection
- Use remote image inspection to determine if base image rebuild is needed
- Include Dockerfile.base in hash calculation for proper dependency tracking
2025-12-19 13:09:56 +08:00
fawney19
1fae202bde Merge pull request #30 from AAEE86/master
chore: Modify the order of API format enumeration
2025-12-19 12:34:22 +08:00
fawney19
b9a26c4550 fix: add SETUPTOOLS_SCM_PRETEND_VERSION for CI builds 2025-12-19 12:01:19 +08:00
AAEE86
e42bd35d48 chore: Modify the order of API format enumeration - Move CLAUDE_CLI before OPENAI 2025-12-19 11:44:10 +08:00
fawney19
f22a073fd9 fix: rebuild app image when base image changes during deployment
- Track BASE_REBUILT flag to detect base image rebuilds
- Force app image rebuild when base image is rebuilt
- Prevents stale app images built with outdated base images
- Ensures consistent deployment when base dependencies change
2025-12-19 11:32:43 +08:00
fawney19
5c7ad089d2 fix: disable nginx buffering for streaming responses
- Add X-Accel-Buffering: no header to prevent nginx from buffering streamed content
- Ensures immediate delivery of each chunk without proxy buffering delays
- Improves real-time streaming performance and responsiveness
- Applies to both production and local Dockerfiles
2025-12-19 11:26:15 +08:00
fawney19
97425ac68f refactor: make stream smoothing parameters configurable and add models cache invalidation
- Move stream smoothing parameters (chunk_size, delay_ms) to database config
- Remove hardcoded stream smoothing constants from StreamProcessor
- Simplify dynamic delay calculation by using config values directly
- Add invalidate_models_list_cache() function to clear /v1/models endpoint cache
- Call cache invalidation on model create, update, delete, and bulk operations
- Update admin UI to allow runtime configuration of smoothing parameters
- Improve model listing freshness when models are modified
2025-12-19 11:03:46 +08:00
fawney19
912f6643e2 tune: adjust stream smoothing parameters for better user experience
- Increase chunk size from 5 to 20 characters for fewer delays
- Reduce min delay from 15ms to 8ms for faster playback
- Reduce max delay from 24ms to 15ms for better responsiveness
- Adjust text thresholds to better differentiate content types
- Apply parameter tuning to both StreamProcessor and _LightweightSmoother
2025-12-19 09:51:09 +08:00
fawney19
6c0373fda6 refactor: simplify text splitting logic in stream processor
- Remove complex conditional logic for short/medium/long text differentiation
- Unify text splitting to always use consistent CHUNK_SIZE-based splitting
- Rely on dynamic delay calculation for output speed adjustment
- Reduce code complexity in both main smoother and lightweight smoother
2025-12-19 09:48:11 +08:00
fawney19
070121717d refactor: consolidate stream smoothing into StreamProcessor with intelligent timing
- Move StreamSmoother functionality directly into StreamProcessor for better integration
- Create ContentExtractor strategy pattern for format-agnostic content extraction
- Implement intelligent dynamic delay calculation based on text length
- Support three text length tiers: short (char-by-char), medium (chunked), long (chunked)
- Remove manual chunk_size and delay_ms configuration - now auto-calculated
- Simplify admin UI to single toggle switch with auto timing adjustment
- Extract format detection logic to reusable content_extractors module
- Improve code maintainability with cleaner architecture
2025-12-19 09:46:22 +08:00
fawney19
85fafeacb8 feat: add stream smoothing feature for improved user experience
- Implement StreamSmoother class to split large content chunks into smaller pieces with delay
- Support OpenAI, Claude, and Gemini API response formats for smooth playback
- Add stream smoothing configuration to system settings (enable, chunk size, delay)
- Create streamlined API for stream smoothing with StreamSmoothingConfig dataclass
- Add admin UI controls for configuring stream smoothing parameters
- Use batch configuration loading to minimize database queries
- Enable typing effect simulation for better user experience in streaming responses
2025-12-19 03:15:19 +08:00
fawney19
daf8b870f0 fix: include Dockerfile.base.local in dependency hash calculation
- Add Dockerfile.base.local to deps hash to detect Docker configuration changes
- Ensures deployment rebuilds when nginx proxy settings are modified
- Prevents stale Docker images from being reused after config changes
2025-12-19 02:38:46 +08:00
fawney19
880fb61c66 fix: disable gzip compression in nginx proxy configuration
- Add gzip off directive to prevent nginx from compressing proxied responses
- Ensures stream integrity for chunked transfer encoding
- Applies to both production and local Dockerfiles
2025-12-19 02:17:07 +08:00
fawney19
7e792dabfc refactor: use background task for client disconnection monitoring
- Replace time-based throttling with background task for disconnect checks
- Remove time.monotonic() and related throttling logic
- Prevent blocking of stream transmission during connection checks
- Properly clean up background task with try/finally block
- Improve throughput and responsiveness of stream processing
2025-12-19 01:59:56 +08:00
fawney19
cd06169b2f fix: detect OpenAI format stream completion via finish_reason
- Add detection of finish_reason in OpenAI API responses to mark stream completion
- Ensures OpenAI API streams are properly marked as complete even without explicit completion events
- Complements existing completion event detection for other API formats
2025-12-19 01:44:35 +08:00
fawney19
50ffd47546 fix: handle client disconnection after stream completion gracefully
- Check has_completion flag before marking client disconnection as failure
- Allow graceful termination if response already completed when client disconnects
- Change logging level to info for post-completion disconnections
- Prevent false error reporting when client closes connection after receiving full response
2025-12-19 01:36:20 +08:00
fawney19
5f0c1fb347 refactor: remove unused response normalizer module
- Delete unused ResponseNormalizer class and its initialization logic
- Remove response_normalizer and enable_response_normalization parameters from handlers
- Simplify chat adapter base initialization by removing normalizer setup
- Clean up unused imports in handler modules
2025-12-19 01:20:30 +08:00
fawney19
7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00
fawney19
c7b971cfe7 Merge pull request #29 from fawney19/feature/proxy-support
feat: add HTTP/SOCKS5 proxy support for API endpoints
2025-12-18 16:18:25 +08:00
fawney19
293bb592dc fix: enhance proxy configuration with password preservation and UI improvements
- Add 'enabled' field to ProxyConfig for preserving config when disabled
- Mask proxy password in API responses (return '***' instead of actual password)
- Preserve existing password on update when new password not provided
- Add URL encoding for proxy credentials (handle special chars like @, :, /)
- Enhanced URL validation: block SOCKS4, require valid host, forbid embedded auth
- UI improvements: use Switch component, dynamic password placeholder
- Add confirmation dialog for orphaned credentials (URL empty but has username/password)
- Prevent browser password autofill with randomized IDs and CSS text-security
- Unify ProxyConfig type definition in types.ts
2025-12-18 16:14:37 +08:00
fawney19
3e50c157be feat: add HTTP/SOCKS5 proxy support for API endpoints
- Add proxy field to ProviderEndpoint database model with migration
- Add ProxyConfig Pydantic model for proxy URL validation
- Extend HTTP client pool with create_client_with_proxy method
- Integrate proxy configuration in chat_handler_base.py and cli_handler_base.py
- Update admin API endpoints to support proxy configuration CRUD
- Add proxy configuration UI in frontend EndpointFormDialog

Fixes #28
2025-12-18 14:46:47 +08:00
fawney19
21587449c8 fix: improve error classification and logging system
- Enhance error classifier to properly handle API key failures with fallback support
- Add error reason/code parsing for better AWS and multi-provider compatibility
- Improve error message structure detection for non-standard formats
- Refactor file logging with size-based rotation (100MB) instead of daily
- Optimize production logging by disabling backtrace and diagnose
- Clean up model validation and remove redundant configurations
2025-12-18 10:57:31 +08:00
fawney19
3d0ab353d3 refactor: migrate Pydantic Config to v2 ConfigDict 2025-12-18 02:20:53 +08:00
fawney19
b2a857c164 refactor: consolidate transaction management and remove legacy modules
- Remove unused context.py module (replaced by request.state)
- Remove provider_cache.py (no longer needed)
- Unify environment loading in config/settings.py instead of __init__.py
- Add deprecation warning for get_async_db() (consolidating on sync Session)
- Enhance database.py documentation with comprehensive transaction strategy
- Simplify audit logging to reuse request-level Session (no separate connections)
- Extract UsageService._build_usage_params() helper to reduce code duplication
- Update model and user cache implementations with refined transaction handling
- Remove unnecessary sessionmaker from pipeline
- Clean up audit service exception handling
2025-12-18 01:59:40 +08:00
fawney19
4d1d863916 refactor: improve authentication and user data handling
- Replace user cache queries with direct database queries to ensure data consistency
- Fix token_type parameter in verify_token calls (access token verification)
- Fix role-based permission check using dictionary ranking instead of string comparison
- Fix logout operation to use correct JWT claim name (user_id instead of sub)
- Simplify user authentication flow by removing unnecessary cache layer
- Optimize session initialization in main.py using create_session helper
- Remove unused imports and exception variables
2025-12-18 01:09:22 +08:00
fawney19
b579420690 refactor: optimize database session lifecycle and middleware architecture
- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
2025-12-18 00:35:46 +08:00
fawney19
9d5c84f9d3 refactor: add scheduling mode support and optimize system settings UI
- Add fixed_order and cache_affinity scheduling modes to CacheAwareScheduler
- Only apply cache affinity in cache_affinity mode; use fixed order otherwise
- Simplify Dialog components with title/description props
- Remove unnecessary button shadows in SystemSettings
- Optimize import dialog UI structure
- Update ModelAliasesTab shadow styling
- Fix fallback orchestrator type hints
- Add scheduling_mode configuration in system config
2025-12-17 19:15:08 +08:00
fawney19
53e6a82480 Merge pull request #25 from fawney19/fix/22-dark-mode-settings-bug
fix: 修复个人设置页面深色模式切换后刷新失效的问题
2025-12-17 18:11:53 +08:00
fawney19
bd11ebdbd5 fix: 修复个人设置页面深色模式切换后刷新失效的问题
- 前端使用 useDarkMode composable 统一主题切换逻辑
- 后端支持 system 主题值(之前只支持 auto)
- 主题以本地 localStorage 为准,避免刷新时被服务端旧值覆盖

Fixes #22
2025-12-17 18:02:19 +08:00
fawney19
1dac4cb156 refactor: optimize provider query and stats aggregation logic 2025-12-17 16:41:10 +08:00
fawney19
50abb55c94 fix(models): clear form state when loading model data for edit
Reset model selection, search query, and expanded provider state
when switching to edit mode to prevent stale UI state carrying over
from previous operations. Also ensure tieredPricing is properly set
or reset based on model data.
2025-12-16 18:42:58 +08:00
fawney19
73d3c9d3e4 ui(models): display model ID in global model form dialog
Show model ID below model name in the dropdown list for better clarity
when selecting models, with appropriate text styling for selected state.
2025-12-16 18:36:23 +08:00
fawney19
d24c3885ab feat(admin): add config and user data import/export functionality
Add comprehensive import/export endpoints for:
- Provider and model configuration (with key decryption for export)
- User data and API keys (preserving encrypted data)

Includes merge modes (skip/overwrite/error) for conflict handling,
10MB size limit for imports, and automatic cache invalidation.

Also fix optional field in GlobalModelResponse tiered_pricing.
2025-12-16 18:33:14 +08:00
fawney19
d696c575e6 refactor(migrations): add idempotency checks to migration scripts 2025-12-16 17:46:38 +08:00
fawney19
46ff5a1a50 refactor(models): enhance model management with official provider marking and extended metadata
- Add OFFICIAL_PROVIDERS set to mark first-party vendors in models.dev
- Implement official provider marking function with cache compatibility
- Extend model metadata with family, context_limit, output_limit fields
- Improve frontend model selection UI with wider panel and better search
- Add dark mode support for provider logos
- Optimize scrollbar styling for model lists
- Update deployment documentation with clearer migration steps
2025-12-16 17:28:40 +08:00
fawney19
edce43d45f fix(auth): make get_current_user and get_current_user_from_header async functions
将 get_current_user 和 get_current_user_from_header 函数声明为 async,
并更新 AuthService.verify_token 的调用为 await,以正确处理异步 Token 验证。
2025-12-16 13:42:26 +08:00
fawney19
33265b4b13 refactor(global-model): migrate model metadata to flexible config structure
将模型配置从多个固定字段(description, official_url, icon_url, default_supports_* 等)
统一为灵活的 config JSON 字段,提高扩展性。同时优化前端模型创建表单,支持从 models-dev
列表直接选择模型快速填充。

主要变更:
- 后端:模型表迁移,支持 config JSON 存储模型能力和元信息
- 前端:GlobalModelFormDialog 支持两种创建方式(列表选择/手动填写)
- API 类型更新,对齐新的数据结构
2025-12-16 12:21:21 +08:00
fawney19
a94aeca2d3 docs(deploy): add database migration step to deployment guide and create migration script 2025-12-16 09:21:24 +08:00
fawney19
c42ebdd0ee test(handler): add comprehensive stream processor unit tests 2025-12-16 02:40:26 +08:00
fawney19
f1e3c2ab11 feat(frontend-usage): enhance usage UI with first byte latency metrics
- Update usage records table to display first_byte_time_ms metrics
- Improve request timeline visualization for latency tracking
- Extend usage types for new timing information
2025-12-16 02:39:54 +08:00
fawney19
4e2ba0e57f feat(usage): add first_byte_time_ms tracking to usage statistics
- Enhance usage service to capture and store first byte latency metrics
- Update usage API routes to include new timing information
2025-12-16 02:39:36 +08:00
fawney19
a3df41d63d refactor(cli-handler): improve stream handling and response processing
- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
2025-12-16 02:39:20 +08:00
fawney19
ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00
fawney19
9b496abb73 feat(db): add first_byte_time_ms column to usage table 2025-12-16 02:38:43 +08:00
139 changed files with 10143 additions and 3622 deletions

View File

@@ -15,6 +15,8 @@ env:
REGISTRY: ghcr.io
BASE_IMAGE_NAME: fawney19/aether-base
APP_IMAGE_NAME: fawney19/aether
# Files that affect base image - used for hash calculation
BASE_FILES: "Dockerfile.base pyproject.toml frontend/package.json frontend/package-lock.json"
jobs:
check-base-changes:
@@ -23,8 +25,13 @@ jobs:
base_changed: ${{ steps.check.outputs.base_changed }}
steps:
- uses: actions/checkout@v4
- name: Log in to Container Registry
uses: docker/login-action@v3
with:
fetch-depth: 2
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Check if base image needs rebuild
id: check
@@ -34,10 +41,26 @@ jobs:
exit 0
fi
# Check if base-related files changed
if git diff --name-only HEAD~1 HEAD | grep -qE '^(Dockerfile\.base|pyproject\.toml|frontend/package.*\.json)$'; then
# Calculate current hash of base-related files
CURRENT_HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
echo "Current base files hash: $CURRENT_HASH"
# Try to get hash label from remote image config
# Pull the image config and extract labels
REMOTE_HASH=""
if docker pull ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest 2>/dev/null; then
REMOTE_HASH=$(docker inspect ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest --format '{{ index .Config.Labels "org.opencontainers.image.base.hash" }}' 2>/dev/null) || true
fi
if [ -z "$REMOTE_HASH" ] || [ "$REMOTE_HASH" == "<no value>" ]; then
# No remote image or no hash label, need to rebuild
echo "No remote base image or hash label found, need rebuild"
echo "base_changed=true" >> $GITHUB_OUTPUT
elif [ "$CURRENT_HASH" != "$REMOTE_HASH" ]; then
echo "Hash mismatch: remote=$REMOTE_HASH, current=$CURRENT_HASH"
echo "base_changed=true" >> $GITHUB_OUTPUT
else
echo "Hash matches, no rebuild needed"
echo "base_changed=false" >> $GITHUB_OUTPUT
fi
@@ -61,6 +84,12 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Calculate base files hash
id: hash
run: |
HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
echo "hash=$HASH" >> $GITHUB_OUTPUT
- name: Extract metadata for base image
id: meta
uses: docker/metadata-action@v5
@@ -69,6 +98,8 @@ jobs:
tags: |
type=raw,value=latest
type=sha,prefix=
labels: |
org.opencontainers.image.base.hash=${{ steps.hash.outputs.hash }}
- name: Build and push base image
uses: docker/build-push-action@v5
@@ -117,7 +148,7 @@ jobs:
- name: Update Dockerfile.app to use registry base image
run: |
sed -i "s|FROM aether-base:latest|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest|g" Dockerfile.app
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
- name: Build and push app image
uses: docker/build-push-action@v5

View File

@@ -1,16 +1,134 @@
# 应用镜像:基于基础镜像,只复制代码(秒级构建)
# 运行镜像:从 base 提取产物到精简运行时
# 构建命令: docker build -f Dockerfile.app -t aether-app:latest .
FROM aether-base:latest
# 用于 GitHub Actions CI官方源
FROM aether-base:latest AS builder
WORKDIR /app
# 复制前端源码并构建
COPY frontend/ ./frontend/
RUN cd frontend && npm run build
# ==================== 运行时镜像 ====================
FROM python:3.12-slim
WORKDIR /app
# 运行时依赖(无 gcc/nodejs/npm
RUN apt-get update && apt-get install -y \
nginx \
supervisor \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/*
# 从 base 镜像复制 Python 包
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
# 只复制需要的 Python 可执行文件
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/
# 从 builder 阶段复制前端构建产物
COPY --from=builder /app/frontend/dist /usr/share/nginx/html
# 复制后端代码
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/
# 构建前端(使用基础镜像中已安装的 node_modules
COPY frontend/ /tmp/frontend/
RUN cd /tmp/frontend && npm run build && \
cp -r dist/* /usr/share/nginx/html/ && \
rm -rf /tmp/frontend
# Nginx 配置模板
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' gzip off;' \
' add_header X-Accel-Buffering no;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor 配置
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

135
Dockerfile.app.local Normal file
View File

@@ -0,0 +1,135 @@
# 运行镜像:从 base 提取产物到精简运行时(国内镜像源版本)
# 构建命令: docker build -f Dockerfile.app.local -t aether-app:latest .
# 用于本地/国内服务器部署
FROM aether-base:latest AS builder
WORKDIR /app
# 复制前端源码并构建
COPY frontend/ ./frontend/
RUN cd frontend && npm run build
# ==================== 运行时镜像 ====================
FROM python:3.12-slim
WORKDIR /app
# 运行时依赖(使用清华镜像源)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
apt-get update && apt-get install -y \
nginx \
supervisor \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/*
# 从 base 镜像复制 Python 包
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
# 只复制需要的 Python 可执行文件
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/
# 从 builder 阶段复制前端构建产物
COPY --from=builder /app/frontend/dist /usr/share/nginx/html
# 复制后端代码
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/
# Nginx 配置模板
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' gzip off;' \
' add_header X-Accel-Buffering no;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor 配置
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -1,122 +1,25 @@
# 基础镜像:包含所有依赖,只在依赖变化时需要重建
# 构建镜像:编译环境 + 预编译的依赖
# 用于 GitHub Actions CI 构建(不使用国内镜像源)
# 构建命令: docker build -f Dockerfile.base -t aether-base:latest .
# 只在 pyproject.toml 或 frontend/package*.json 变化时需要重建
FROM python:3.12-slim
WORKDIR /app
# 系统依赖
# 构建工具
RUN apt-get update && apt-get install -y \
nginx \
supervisor \
libpq-dev \
gcc \
curl \
gettext-base \
nodejs \
npm \
&& rm -rf /var/lib/apt/lists/*
# Python 依赖(安装到系统,不用 -e 模式)
# Python 依赖
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
pip install --no-cache-dir .
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
pip cache purge
# 前端依赖
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm ci
# Nginx 配置模板
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor 配置
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html
WORKDIR /app
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# 前端依赖(只安装,不构建)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm ci

View File

@@ -1,18 +1,15 @@
# 基础镜像:包含所有依赖,只在依赖变化时需要重建
# 构建命令: docker build -f Dockerfile.base -t aether-base:latest .
# 构建镜像:编译环境 + 预编译的依赖(国内镜像源版本)
# 构建命令: docker build -f Dockerfile.base.local -t aether-base:latest .
# 只在 pyproject.toml 或 frontend/package*.json 变化时需要重建
FROM python:3.12-slim
WORKDIR /app
# 系统依赖
# 构建工具(使用清华镜像源)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
apt-get update && apt-get install -y \
nginx \
supervisor \
libpq-dev \
gcc \
curl \
gettext-base \
nodejs \
npm \
&& rm -rf /var/lib/apt/lists/*
@@ -20,107 +17,12 @@ RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.li
# pip 镜像源
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
# Python 依赖(安装到系统,不用 -e 模式)
# Python 依赖
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir .
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
pip cache purge
# 前端依赖
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm config set registry https://registry.npmmirror.com && npm ci
# Nginx 配置模板
RUN printf '%s\n' \
'server {' \
' listen 80;' \
' server_name _;' \
' root /usr/share/nginx/html;' \
' index index.html;' \
' client_max_body_size 100M;' \
'' \
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
' expires 1y;' \
' add_header Cache-Control "public, no-transform";' \
' try_files $uri =404;' \
' }' \
'' \
' location ~ ^/(src|node_modules)/ {' \
' deny all;' \
' return 404;' \
' }' \
'' \
' location ~ ^/(dashboard|admin|login)(/|$) {' \
' try_files $uri $uri/ /index.html;' \
' }' \
'' \
' location / {' \
' try_files $uri $uri/ @backend;' \
' }' \
'' \
' location @backend {' \
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
' proxy_http_version 1.1;' \
' proxy_set_header Host $host;' \
' proxy_set_header X-Real-IP $remote_addr;' \
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
' proxy_set_header X-Forwarded-Proto $scheme;' \
' proxy_set_header Connection "";' \
' proxy_set_header Accept $http_accept;' \
' proxy_set_header Content-Type $content_type;' \
' proxy_set_header Authorization $http_authorization;' \
' proxy_set_header X-Api-Key $http_x_api_key;' \
' proxy_buffering off;' \
' proxy_cache off;' \
' proxy_request_buffering off;' \
' chunked_transfer_encoding on;' \
' proxy_connect_timeout 600s;' \
' proxy_send_timeout 600s;' \
' proxy_read_timeout 600s;' \
' }' \
'}' > /etc/nginx/sites-available/default.template
# Supervisor 配置
RUN printf '%s\n' \
'[supervisord]' \
'nodaemon=true' \
'logfile=/var/log/supervisor/supervisord.log' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/dev/stdout' \
'stdout_logfile_maxbytes=0' \
'stderr_logfile=/dev/stderr' \
'stderr_logfile_maxbytes=0' \
'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html
WORKDIR /app
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONIOENCODING=utf-8 \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
PORT=8084
EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# 前端依赖(只安装,不构建,使用淘宝镜像源)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm config set registry https://registry.npmmirror.com && npm ci

View File

@@ -60,8 +60,11 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
# 3. 部署
docker-compose up -d
# 4. 更新
docker-compose pull && docker-compose up -d
# 4. 首次部署时, 初始化数据库
./migrate.sh
# 5. 更新
docker-compose pull && docker-compose up -d && ./migrate.sh
```
### Docker Compose本地构建镜像

View File

@@ -20,10 +20,10 @@ depends_on = None
def upgrade() -> None:
# Create ENUM types
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
# Create ENUM types (with IF NOT EXISTS for idempotency)
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
op.execute(
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
)
# ==================== users ====================
@@ -35,7 +35,7 @@ def upgrade() -> None:
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column(
"role",
sa.Enum("admin", "user", name="userrole", create_type=False),
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
nullable=False,
server_default="user",
),
@@ -67,7 +67,7 @@ def upgrade() -> None:
sa.Column("website", sa.String(500), nullable=True),
sa.Column(
"billing_type",
sa.Enum(
postgresql.ENUM(
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
),
nullable=False,

View File

@@ -26,16 +26,66 @@ branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def table_exists(bind, table_name: str) -> bool:
"""检查表是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_name = :table_name
)
"""
),
{"table_name": table_name},
)
return result.scalar()
def index_exists(bind, index_name: str) -> bool:
"""检查索引是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM pg_indexes
WHERE indexname = :index_name
)
"""
),
{"index_name": index_name},
)
return result.scalar()
def upgrade() -> None:
"""添加 provider_model_aliases 字段,迁移数据,删除 model_mappings 表"""
# 1. 添加 provider_model_aliases 字段
op.add_column(
'models',
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
)
# 2. 迁移 model_mappings 数据
bind = op.get_bind()
# 1. 添加 provider_model_aliases 字段(如果不存在)
if not column_exists(bind, "models", "provider_model_aliases"):
op.add_column(
'models',
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
)
# 2. 迁移 model_mappings 数据(如果表存在)
session = Session(bind=bind)
model_mappings_table = sa.table(
@@ -96,104 +146,118 @@ def upgrade() -> None:
# 查询所有活跃的 provider 级别 alias只迁移 is_active=True 且 mapping_type='alias' 的)
# 全局别名/映射不迁移(新架构不再支持 source_model -> GlobalModel.name 的解析)
mappings = session.execute(
sa.select(
model_mappings_table.c.source_model,
model_mappings_table.c.target_global_model_id,
model_mappings_table.c.provider_id,
)
.where(
model_mappings_table.c.is_active.is_(True),
model_mappings_table.c.provider_id.isnot(None),
model_mappings_table.c.mapping_type == "alias",
)
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
).all()
# 按 (provider_id, target_global_model_id) 分组,收集别名
alias_groups: dict = {}
for source_model, target_global_model_id, provider_id in mappings:
if not isinstance(source_model, str):
continue
source_model = source_model.strip()
if not source_model:
continue
if not isinstance(provider_id, str) or not provider_id:
continue
if not isinstance(target_global_model_id, str) or not target_global_model_id:
continue
key = (provider_id, target_global_model_id)
if key not in alias_groups:
alias_groups[key] = []
priority = len(alias_groups[key]) + 1
alias_groups[key].append({"name": source_model, "priority": priority})
# 更新对应的 models 记录
for (provider_id, global_model_id), aliases in alias_groups.items():
model_row = session.execute(
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
# 仅当 model_mappings 表存在时执行迁移
if table_exists(bind, "model_mappings"):
mappings = session.execute(
sa.select(
model_mappings_table.c.source_model,
model_mappings_table.c.target_global_model_id,
model_mappings_table.c.provider_id,
)
.where(
models_table.c.provider_id == provider_id,
models_table.c.global_model_id == global_model_id,
model_mappings_table.c.is_active.is_(True),
model_mappings_table.c.provider_id.isnot(None),
model_mappings_table.c.mapping_type == "alias",
)
.limit(1)
).first()
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
).all()
if model_row:
model_id = model_row[0]
existing_aliases = normalize_alias_list(model_row[1])
# 按 (provider_id, target_global_model_id) 分组,收集别名
alias_groups: dict = {}
for source_model, target_global_model_id, provider_id in mappings:
if not isinstance(source_model, str):
continue
source_model = source_model.strip()
if not source_model:
continue
if not isinstance(provider_id, str) or not provider_id:
continue
if not isinstance(target_global_model_id, str) or not target_global_model_id:
continue
existing_names = {a["name"] for a in existing_aliases}
merged_aliases = list(existing_aliases)
for alias in aliases:
name = alias.get("name")
if not isinstance(name, str):
continue
name = name.strip()
if not name or name in existing_names:
continue
key = (provider_id, target_global_model_id)
if key not in alias_groups:
alias_groups[key] = []
priority = len(alias_groups[key]) + 1
alias_groups[key].append({"name": source_model, "priority": priority})
merged_aliases.append(
{
"name": name,
"priority": len(merged_aliases) + 1,
}
# 更新对应的 models 记录
for (provider_id, global_model_id), aliases in alias_groups.items():
model_row = session.execute(
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
.where(
models_table.c.provider_id == provider_id,
models_table.c.global_model_id == global_model_id,
)
existing_names.add(name)
.limit(1)
).first()
session.execute(
models_table.update()
.where(models_table.c.id == model_id)
.values(
provider_model_aliases=merged_aliases if merged_aliases else None,
updated_at=datetime.now(timezone.utc),
if model_row:
model_id = model_row[0]
existing_aliases = normalize_alias_list(model_row[1])
existing_names = {a["name"] for a in existing_aliases}
merged_aliases = list(existing_aliases)
for alias in aliases:
name = alias.get("name")
if not isinstance(name, str):
continue
name = name.strip()
if not name or name in existing_names:
continue
merged_aliases.append(
{
"name": name,
"priority": len(merged_aliases) + 1,
}
)
existing_names.add(name)
session.execute(
models_table.update()
.where(models_table.c.id == model_id)
.values(
provider_model_aliases=merged_aliases if merged_aliases else None,
updated_at=datetime.now(timezone.utc),
)
)
)
session.commit()
session.commit()
# 3. 删除 model_mappings 表
op.drop_table('model_mappings')
# 3. 删除 model_mappings 表
op.drop_table('model_mappings')
# 4. 添加索引优化别名解析性能
# provider_model_name 索引(支持精确匹配)
op.create_index(
"idx_model_provider_model_name",
"models",
["provider_model_name"],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# provider_model_name 索引(支持精确匹配,如果不存在
if not index_exists(bind, "idx_model_provider_model_name"):
op.create_index(
"idx_model_provider_model_name",
"models",
["provider_model_name"],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# provider_model_aliases GIN 索引(支持 JSONB 查询,仅 PostgreSQL
if bind.dialect.name == "postgresql":
# 将 json 列转为 jsonbjsonb 性能更好且支持 GIN 索引)
# 使用 IF NOT EXISTS 风格的检查来避免重复转换
op.execute(
"""
ALTER TABLE models
ALTER COLUMN provider_model_aliases TYPE jsonb
USING provider_model_aliases::jsonb
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'models'
AND column_name = 'provider_model_aliases'
AND data_type = 'json'
) THEN
ALTER TABLE models
ALTER COLUMN provider_model_aliases TYPE jsonb
USING provider_model_aliases::jsonb;
END IF;
END $$;
"""
)
# 创建 GIN 索引

View File

@@ -0,0 +1,47 @@
"""add first_byte_time_ms to usage table
Revision ID: 180e63a9c83a
Revises: e9b3d63f0cbf
Create Date: 2025-12-15 17:07:44.631032+00:00
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '180e63a9c83a'
down_revision = 'e9b3d63f0cbf'
branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def upgrade() -> None:
"""应用迁移:升级到新版本"""
bind = op.get_bind()
# 添加首字时间字段到 usage 表(如果不存在)
if not column_exists(bind, "usage", "first_byte_time_ms"):
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 删除首字时间字段
op.drop_column('usage', 'first_byte_time_ms')

View File

@@ -0,0 +1,110 @@
"""refactor global_model to use config json field
Revision ID: 1cc6942cf06f
Revises: 180e63a9c83a
Create Date: 2025-12-16 03:11:32.480976+00:00
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '1cc6942cf06f'
down_revision = '180e63a9c83a'
branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def upgrade() -> None:
"""应用迁移:升级到新版本
1. 添加 config 列
2. 把旧数据迁移到 config
3. 删除旧列
"""
bind = op.get_bind()
# 检查是否已经迁移过config 列存在且旧列不存在)
has_config = column_exists(bind, "global_models", "config")
has_old_columns = column_exists(bind, "global_models", "default_supports_streaming")
if has_config and not has_old_columns:
# 已完成迁移,跳过
return
# 1. 添加 config 列(使用 JSONB 类型,支持索引和更高效的查询)
if not has_config:
op.add_column('global_models', sa.Column('config', postgresql.JSONB(), nullable=True))
# 2. 迁移数据:把旧字段合并到 config JSON仅当旧列存在时
if has_old_columns:
op.execute("""
UPDATE global_models
SET config = jsonb_strip_nulls(jsonb_build_object(
'streaming', COALESCE(default_supports_streaming, true),
'vision', CASE WHEN COALESCE(default_supports_vision, false) THEN true ELSE NULL END,
'function_calling', CASE WHEN COALESCE(default_supports_function_calling, false) THEN true ELSE NULL END,
'extended_thinking', CASE WHEN COALESCE(default_supports_extended_thinking, false) THEN true ELSE NULL END,
'image_generation', CASE WHEN COALESCE(default_supports_image_generation, false) THEN true ELSE NULL END,
'description', description,
'icon_url', icon_url,
'official_url', official_url
))
""")
# 3. 删除旧列
op.drop_column('global_models', 'default_supports_streaming')
op.drop_column('global_models', 'default_supports_vision')
op.drop_column('global_models', 'default_supports_function_calling')
op.drop_column('global_models', 'default_supports_extended_thinking')
op.drop_column('global_models', 'default_supports_image_generation')
op.drop_column('global_models', 'description')
op.drop_column('global_models', 'icon_url')
op.drop_column('global_models', 'official_url')
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 1. 添加旧列
op.add_column('global_models', sa.Column('icon_url', sa.VARCHAR(length=500), nullable=True))
op.add_column('global_models', sa.Column('official_url', sa.VARCHAR(length=500), nullable=True))
op.add_column('global_models', sa.Column('description', sa.TEXT(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_streaming', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_vision', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_function_calling', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_extended_thinking', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_image_generation', sa.BOOLEAN(), nullable=True))
# 2. 从 config 恢复数据
op.execute("""
UPDATE global_models
SET
default_supports_streaming = COALESCE((config->>'streaming')::boolean, true),
default_supports_vision = COALESCE((config->>'vision')::boolean, false),
default_supports_function_calling = COALESCE((config->>'function_calling')::boolean, false),
default_supports_extended_thinking = COALESCE((config->>'extended_thinking')::boolean, false),
default_supports_image_generation = COALESCE((config->>'image_generation')::boolean, false),
description = config->>'description',
icon_url = config->>'icon_url',
official_url = config->>'official_url'
""")
# 3. 删除 config 列
op.drop_column('global_models', 'config')

View File

@@ -0,0 +1,57 @@
"""add proxy field to provider_endpoints
Revision ID: f30f9936f6a2
Revises: 1cc6942cf06f
Create Date: 2025-12-18 06:31:58.451112+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'f30f9936f6a2'
down_revision = '1cc6942cf06f'
branch_labels = None
depends_on = None
def column_exists(table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
bind = op.get_bind()
inspector = inspect(bind)
columns = [col['name'] for col in inspector.get_columns(table_name)]
return column_name in columns
def get_column_type(table_name: str, column_name: str) -> str:
"""获取列的类型"""
bind = op.get_bind()
inspector = inspect(bind)
for col in inspector.get_columns(table_name):
if col['name'] == column_name:
return str(col['type']).upper()
return ''
def upgrade() -> None:
"""添加 proxy 字段到 provider_endpoints 表"""
if not column_exists('provider_endpoints', 'proxy'):
# 字段不存在,直接添加 JSONB 类型
op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
else:
# 字段已存在,检查是否需要转换类型
col_type = get_column_type('provider_endpoints', 'proxy')
if 'JSONB' not in col_type:
# 如果是 JSON 类型,转换为 JSONB
op.execute(
'ALTER TABLE provider_endpoints '
'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
)
def downgrade() -> None:
"""移除 proxy 字段"""
if column_exists('provider_endpoints', 'proxy'):
op.drop_column('provider_endpoints', 'proxy')

View File

@@ -21,9 +21,9 @@ HASH_FILE=".deps-hash"
CODE_HASH_FILE=".code-hash"
MIGRATION_HASH_FILE=".migration-hash"
# 计算依赖文件的哈希值
# 计算依赖文件的哈希值(包含 Dockerfile.base.local
calc_deps_hash() {
cat pyproject.toml frontend/package.json frontend/package-lock.json 2>/dev/null | md5sum | cut -d' ' -f1
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
}
# 计算代码文件的哈希值
@@ -88,7 +88,7 @@ build_base() {
# 构建应用镜像
build_app() {
echo ">>> Building app image (code only)..."
docker build -f Dockerfile.app -t aether-app:latest .
docker build -f Dockerfile.app.local -t aether-app:latest .
save_code_hash
}
@@ -162,25 +162,32 @@ git pull
# 标记是否需要重启
NEED_RESTART=false
BASE_REBUILT=false
# 检查基础镜像是否存在,或依赖是否变化
if ! docker image inspect aether-base:latest >/dev/null 2>&1; then
echo ">>> Base image not found, building..."
build_base
BASE_REBUILT=true
NEED_RESTART=true
elif check_deps_changed; then
echo ">>> Dependencies changed, rebuilding base image..."
build_base
BASE_REBUILT=true
NEED_RESTART=true
else
echo ">>> Dependencies unchanged."
fi
# 检查代码是否变化
# 检查代码是否变化,或者 base 重建了app 依赖 base
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
echo ">>> App image not found, building..."
build_app
NEED_RESTART=true
elif [ "$BASE_REBUILT" = true ]; then
echo ">>> Base image rebuilt, rebuilding app image..."
build_app
NEED_RESTART=true
elif check_code_changed; then
echo ">>> Code changed, rebuilding app image..."
build_app

View File

@@ -41,7 +41,7 @@ services:
app:
build:
context: .
dockerfile: Dockerfile.app
dockerfile: Dockerfile.app.local
image: aether-app:latest
container_name: aether-app
environment:

View File

@@ -1,5 +1,179 @@
import apiClient from './client'
// 配置导出数据结构
export interface ConfigExportData {
version: string
exported_at: string
global_models: GlobalModelExport[]
providers: ProviderExport[]
}
// 用户导出数据结构
export interface UsersExportData {
version: string
exported_at: string
users: UserExport[]
}
export interface UserExport {
email: string
username: string
password_hash: string
role: string
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_models?: string[] | null
model_capability_settings?: any
quota_usd?: number | null
used_usd?: number
total_usd?: number
is_active: boolean
api_keys: UserApiKeyExport[]
}
export interface UserApiKeyExport {
key_hash: string
key_encrypted?: string | null
name?: string | null
is_standalone: boolean
balance_used_usd?: number
current_balance_usd?: number | null
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
auto_delete_on_expiry?: boolean
total_requests?: number
total_cost_usd?: number
}
export interface GlobalModelExport {
name: string
display_name: string
default_price_per_request?: number | null
default_tiered_pricing: any
supported_capabilities?: string[] | null
config?: any
is_active: boolean
}
export interface ProviderExport {
name: string
display_name: string
description?: string | null
website?: string | null
billing_type?: string | null
monthly_quota_usd?: number | null
quota_reset_day?: number
rpm_limit?: number | null
provider_priority?: number
is_active: boolean
rate_limit?: number | null
concurrent_limit?: number | null
config?: any
endpoints: EndpointExport[]
models: ModelExport[]
}
export interface EndpointExport {
api_format: string
base_url: string
headers?: any
timeout?: number
max_retries?: number
max_concurrent?: number | null
rate_limit?: number | null
is_active: boolean
custom_path?: string | null
config?: any
keys: KeyExport[]
}
export interface KeyExport {
api_key: string
name?: string | null
note?: string | null
rate_multiplier?: number
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null
rate_limit?: number | null
daily_limit?: number | null
monthly_limit?: number | null
allowed_models?: string[] | null
capabilities?: any
is_active: boolean
}
export interface ModelExport {
global_model_name: string | null
provider_model_name: string
provider_model_aliases?: any
price_per_request?: number | null
tiered_pricing?: any
supports_vision?: boolean | null
supports_function_calling?: boolean | null
supports_streaming?: boolean | null
supports_extended_thinking?: boolean | null
supports_image_generation?: boolean | null
is_active: boolean
config?: any
}
// Provider 模型查询响应
export interface ProviderModelsQueryResponse {
success: boolean
data: {
models: Array<{
id: string
object?: string
created?: number
owned_by?: string
display_name?: string
api_format?: string
}>
error?: string
}
provider: {
id: string
name: string
display_name: string
}
}
export interface ConfigImportRequest extends ConfigExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportRequest extends UsersExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportResponse {
message: string
stats: {
users: { created: number; updated: number; skipped: number }
api_keys: { created: number; skipped: number }
errors: string[]
}
}
export interface ConfigImportResponse {
message: string
stats: {
global_models: { created: number; updated: number; skipped: number }
providers: { created: number; updated: number; skipped: number }
endpoints: { created: number; updated: number; skipped: number }
keys: { created: number; updated: number; skipped: number }
models: { created: number; updated: number; skipped: number }
errors: string[]
}
}
// API密钥管理相关接口定义
export interface AdminApiKey {
id: string // UUID
@@ -173,5 +347,44 @@ export const adminApi = {
'/api/admin/system/api-formats'
)
return response.data
},
// 导出配置
async exportConfig(): Promise<ConfigExportData> {
const response = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
return response.data
},
// 导入配置
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
const response = await apiClient.post<ConfigImportResponse>(
'/api/admin/system/config/import',
data
)
return response.data
},
// 导出用户数据
async exportUsers(): Promise<UsersExportData> {
const response = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
return response.data
},
// 导入用户数据
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
const response = await apiClient.post<UsersImportResponse>(
'/api/admin/system/users/import',
data
)
return response.data
},
// 查询 Provider 可用模型(从上游 API 获取)
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
const response = await apiClient.post<ProviderModelsQueryResponse>(
'/api/admin/provider-query/models',
{ provider_id: providerId, api_key_id: apiKeyId }
)
return response.data
}
}

View File

@@ -66,6 +66,7 @@ export interface UserAffinity {
key_name: string | null
key_prefix: string | null // Provider Key 脱敏显示前4...后4
rate_multiplier: number
global_model_id: string | null // 原始的 global_model_id用于删除
model_name: string | null // 模型名称(如 claude-haiku-4-5-20250514
model_display_name: string | null // 模型显示名称(如 Claude Haiku 4.5
api_format: string | null // API 格式 (claude/openai)
@@ -119,6 +120,18 @@ export const cacheApi = {
await api.delete(`/api/admin/monitoring/cache/users/${userIdentifier}`)
},
/**
* 清除单条缓存亲和性
*
* @param affinityKey API Key ID
* @param endpointId Endpoint ID
* @param modelId GlobalModel ID
* @param apiFormat API 格式 (claude/openai)
*/
async clearSingleAffinity(affinityKey: string, endpointId: string, modelId: string, apiFormat: string): Promise<void> {
await api.delete(`/api/admin/monitoring/cache/affinity/${affinityKey}/${endpointId}/${modelId}/${apiFormat}`)
},
/**
* 清除所有缓存
*/

View File

@@ -1,5 +1,5 @@
import client from '../client'
import type { ProviderEndpoint } from './types'
import type { ProviderEndpoint, ProxyConfig } from './types'
/**
* 获取指定 Provider 的所有 Endpoints
@@ -38,6 +38,7 @@ export async function createEndpoint(
rate_limit?: number
is_active?: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
}
): Promise<ProviderEndpoint> {
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
@@ -63,6 +64,7 @@ export async function updateEndpoint(
rate_limit: number
is_active: boolean
config: Record<string, any>
proxy: ProxyConfig | null
}>
): Promise<ProviderEndpoint> {
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)

View File

@@ -1,3 +1,35 @@
// API 格式常量
export const API_FORMATS = {
CLAUDE: 'CLAUDE',
CLAUDE_CLI: 'CLAUDE_CLI',
OPENAI: 'OPENAI',
OPENAI_CLI: 'OPENAI_CLI',
GEMINI: 'GEMINI',
GEMINI_CLI: 'GEMINI_CLI',
} as const
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
// API 格式显示名称映射按品牌分组API 在前CLI 在后)
export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.CLAUDE]: 'Claude',
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
[API_FORMATS.OPENAI]: 'OpenAI',
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
[API_FORMATS.GEMINI]: 'Gemini',
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
/**
* 代理配置类型
*/
export interface ProxyConfig {
url: string
username?: string
password?: string
enabled?: boolean // 是否启用代理false 时保留配置但不使用)
}
export interface ProviderEndpoint {
id: string
provider_id: string
@@ -19,6 +51,7 @@ export interface ProviderEndpoint {
last_failure_at?: string
is_active: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
total_keys: number
active_keys: number
created_at: string
@@ -214,6 +247,7 @@ export interface ConcurrencyStatus {
export interface ProviderModelAlias {
name: string
priority: number // 优先级(数字越小优先级越高)
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
}
export interface Model {
@@ -407,67 +441,45 @@ export interface TieredPricingConfig {
export interface GlobalModelCreate {
name: string
display_name: string
description?: string
official_url?: string
icon_url?: string
// 按次计费配置(可选,与阶梯计费叠加)
default_price_per_request?: number
// 阶梯计费配置(必填,固定价格用单阶梯表示)
default_tiered_pricing: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[]
// 模型配置JSON格式- 包含能力、规格、元信息等
config?: Record<string, any>
is_active?: boolean
}
export interface GlobalModelUpdate {
display_name?: string
description?: string
official_url?: string
icon_url?: string
is_active?: boolean
// 按次计费配置
default_price_per_request?: number | null // null 表示清空
// 阶梯计费配置
default_tiered_pricing?: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[] | null
// 模型配置JSON格式- 包含能力、规格、元信息等
config?: Record<string, any> | null
}
export interface GlobalModelResponse {
id: string
name: string
display_name: string
description?: string
official_url?: string
icon_url?: string
is_active: boolean
// 按次计费配置
default_price_per_request?: number
// 阶梯计费配置(必填)
default_tiered_pricing: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[] | null
// 模型配置JSON格式
config?: Record<string, any> | null
// 统计数据
provider_count?: number
alias_count?: number
usage_count?: number
created_at: string
updated_at?: string

View File

@@ -0,0 +1,288 @@
/**
* Models.dev API 服务
* 通过后端代理获取 models.dev 数据(解决跨域问题)
*/
import api from './client'
// 缓存配置
const CACHE_KEY = 'models_dev_cache'
const CACHE_DURATION = 15 * 60 * 1000 // 15 分钟
// Models.dev API 数据结构
export interface ModelsDevCost {
input?: number
output?: number
reasoning?: number
cache_read?: number
}
export interface ModelsDevLimit {
context?: number
output?: number
}
export interface ModelsDevModel {
id: string
name: string
family?: string
reasoning?: boolean
tool_call?: boolean
structured_output?: boolean
temperature?: boolean
attachment?: boolean
knowledge?: string
release_date?: string
last_updated?: string
input?: string[] // 输入模态: text, image, audio, video, pdf
output?: string[] // 输出模态: text, image, audio
open_weights?: boolean
cost?: ModelsDevCost
limit?: ModelsDevLimit
deprecated?: boolean
}
export interface ModelsDevProvider {
id: string
env?: string[]
npm?: string
api?: string
name: string
doc?: string
models: Record<string, ModelsDevModel>
official?: boolean // 是否为官方提供商
}
export type ModelsDevData = Record<string, ModelsDevProvider>
// 扁平化的模型列表项(用于搜索和选择)
export interface ModelsDevModelItem {
providerId: string
providerName: string
modelId: string
modelName: string
family?: string
inputPrice?: number
outputPrice?: number
contextLimit?: number
outputLimit?: number
supportsVision?: boolean
supportsToolCall?: boolean
supportsReasoning?: boolean
supportsStructuredOutput?: boolean
supportsTemperature?: boolean
supportsAttachment?: boolean
openWeights?: boolean
deprecated?: boolean
official?: boolean // 是否来自官方提供商
// 用于 display_metadata 的额外字段
knowledgeCutoff?: string
releaseDate?: string
inputModalities?: string[]
outputModalities?: string[]
}
interface CacheData {
timestamp: number
data: ModelsDevData
}
// 内存缓存
let memoryCache: CacheData | null = null
function hasOfficialFlag(data: ModelsDevData): boolean {
return Object.values(data).some(provider => typeof provider?.official === 'boolean')
}
/**
* 获取 models.dev 数据(带缓存)
*/
export async function getModelsDevData(): Promise<ModelsDevData> {
// 1. 检查内存缓存
if (memoryCache && Date.now() - memoryCache.timestamp < CACHE_DURATION) {
// 兼容旧缓存:没有 official 字段时丢弃,强制刷新一次
if (hasOfficialFlag(memoryCache.data)) {
return memoryCache.data
}
memoryCache = null
}
// 2. 检查 localStorage 缓存
try {
const cached = localStorage.getItem(CACHE_KEY)
if (cached) {
const cacheData: CacheData = JSON.parse(cached)
if (Date.now() - cacheData.timestamp < CACHE_DURATION) {
// 兼容旧缓存:没有 official 字段时丢弃,强制刷新一次
if (hasOfficialFlag(cacheData.data)) {
memoryCache = cacheData
return cacheData.data
}
localStorage.removeItem(CACHE_KEY)
}
}
} catch {
// 缓存解析失败,忽略
}
// 3. 从后端代理获取新数据
const response = await api.get<ModelsDevData>('/api/admin/models/external')
const data = response.data
// 4. 更新缓存
const cacheData: CacheData = {
timestamp: Date.now(),
data,
}
memoryCache = cacheData
try {
localStorage.setItem(CACHE_KEY, JSON.stringify(cacheData))
} catch {
// localStorage 写入失败,忽略
}
return data
}
// 模型列表缓存(避免重复转换)
let modelsListCache: ModelsDevModelItem[] | null = null
let modelsListCacheTimestamp: number | null = null
/**
* 获取扁平化的模型列表
* 数据只加载一次,通过参数过滤官方/全部
*/
export async function getModelsDevList(officialOnly: boolean = true): Promise<ModelsDevModelItem[]> {
const data = await getModelsDevData()
const currentTimestamp = memoryCache?.timestamp ?? 0
// 如果缓存为空或数据已刷新,构建一次
if (!modelsListCache || modelsListCacheTimestamp !== currentTimestamp) {
const items: ModelsDevModelItem[] = []
for (const [providerId, provider] of Object.entries(data)) {
if (!provider.models) continue
for (const [modelId, model] of Object.entries(provider.models)) {
items.push({
providerId,
providerName: provider.name,
modelId,
modelName: model.name || modelId,
family: model.family,
inputPrice: model.cost?.input,
outputPrice: model.cost?.output,
contextLimit: model.limit?.context,
outputLimit: model.limit?.output,
supportsVision: model.input?.includes('image'),
supportsToolCall: model.tool_call,
supportsReasoning: model.reasoning,
supportsStructuredOutput: model.structured_output,
supportsTemperature: model.temperature,
supportsAttachment: model.attachment,
openWeights: model.open_weights,
deprecated: model.deprecated,
official: provider.official,
// display_metadata 相关字段
knowledgeCutoff: model.knowledge,
releaseDate: model.release_date,
inputModalities: model.input,
outputModalities: model.output,
})
}
}
// 按 provider 名称和模型名称排序
items.sort((a, b) => {
const providerCompare = a.providerName.localeCompare(b.providerName)
if (providerCompare !== 0) return providerCompare
return a.modelName.localeCompare(b.modelName)
})
modelsListCache = items
modelsListCacheTimestamp = currentTimestamp
}
// 根据参数过滤
if (officialOnly) {
return modelsListCache.filter(m => m.official)
}
return modelsListCache
}
/**
* 搜索模型
* 搜索时包含所有提供商(包括第三方)
*/
export async function searchModelsDevModels(
query: string,
options?: {
limit?: number
excludeDeprecated?: boolean
}
): Promise<ModelsDevModelItem[]> {
// 搜索时包含全部提供商
const allModels = await getModelsDevList(false)
const { limit = 50, excludeDeprecated = true } = options || {}
const queryLower = query.toLowerCase()
const filtered = allModels.filter((model) => {
if (excludeDeprecated && model.deprecated) return false
// 搜索模型 ID、名称、provider 名称、family
return (
model.modelId.toLowerCase().includes(queryLower) ||
model.modelName.toLowerCase().includes(queryLower) ||
model.providerName.toLowerCase().includes(queryLower) ||
model.family?.toLowerCase().includes(queryLower)
)
})
// 排序:精确匹配优先
filtered.sort((a, b) => {
const aExact =
a.modelId.toLowerCase() === queryLower ||
a.modelName.toLowerCase() === queryLower
const bExact =
b.modelId.toLowerCase() === queryLower ||
b.modelName.toLowerCase() === queryLower
if (aExact && !bExact) return -1
if (!aExact && bExact) return 1
return 0
})
return filtered.slice(0, limit)
}
/**
* 获取特定模型详情
*/
export async function getModelsDevModel(
providerId: string,
modelId: string
): Promise<ModelsDevModel | null> {
const data = await getModelsDevData()
return data[providerId]?.models?.[modelId] || null
}
/**
* 获取 provider logo URL
*/
export function getProviderLogoUrl(providerId: string): string {
return `https://models.dev/logos/${providerId}.svg`
}
/**
* 清除缓存
*/
export function clearModelsDevCache(): void {
memoryCache = null
modelsListCache = null
modelsListCacheTimestamp = null
try {
localStorage.removeItem(CACHE_KEY)
} catch {
// 忽略错误
}
}

View File

@@ -9,20 +9,14 @@ export interface PublicGlobalModel {
id: string
name: string
display_name: string | null
description: string | null
icon_url: string | null
is_active: boolean
// 阶梯计费配置
default_tiered_pricing: TieredPricingConfig
default_price_per_request: number | null // 按次计费价格
// 能力
default_supports_vision: boolean
default_supports_function_calling: boolean
default_supports_streaming: boolean
default_supports_extended_thinking: boolean
default_supports_image_generation: boolean
// Key 能力支持
supported_capabilities: string[] | null
// 模型配置JSON
config: Record<string, any> | null
}
export interface PublicGlobalModelListResponse {

View File

@@ -299,7 +299,7 @@ function formatDuration(ms: number): string {
const hours = Math.floor(ms / (1000 * 60 * 60))
const minutes = Math.floor((ms % (1000 * 60 * 60)) / (1000 * 60))
if (hours > 0) {
return `${hours}h${minutes > 0 ? minutes + 'm' : ''}`
return `${hours}h${minutes > 0 ? `${minutes}m` : ''}`
}
return `${minutes}m`
}

View File

@@ -34,11 +34,10 @@ const buttonClass = computed(() => {
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
const variantClasses = {
default:
'bg-primary text-white shadow-[0_20px_35px_rgba(204,120,92,0.35)] hover:bg-primary/90 hover:shadow-[0_25px_45px_rgba(204,120,92,0.45)]',
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85 shadow-sm',
default: 'bg-primary text-white hover:bg-primary/90',
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85',
outline:
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 shadow-sm backdrop-blur transition-all',
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 backdrop-blur transition-all',
secondary:
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
ghost: 'hover:bg-accent hover:text-accent-foreground',

View File

@@ -2,7 +2,7 @@
<Teleport to="body">
<div
v-if="isOpen"
class="fixed inset-0 overflow-y-auto"
class="fixed inset-0 overflow-y-auto pointer-events-none"
:style="{ zIndex: containerZIndex }"
>
<!-- 背景遮罩 -->
@@ -16,13 +16,13 @@
>
<div
v-if="isOpen"
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity"
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
:style="{ zIndex: backdropZIndex }"
@click="handleClose"
/>
</Transition>
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
<!-- 对话框内容 -->
<Transition
enter-active-class="duration-300 ease-out"
@@ -34,7 +34,7 @@
>
<div
v-if="isOpen"
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border"
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border pointer-events-auto"
:style="{ zIndex: contentZIndex }"
:class="maxWidthClass"
@click.stop

View File

@@ -45,7 +45,7 @@ const props = withDefaults(defineProps<Props>(), {
const contentClass = computed(() =>
cn(
'z-[100] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
'z-[200] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
props.class

View File

@@ -132,7 +132,7 @@
type="number"
min="1"
max="10000"
placeholder="100"
placeholder="留空不限制"
class="h-10"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
/>
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -389,7 +389,7 @@ function resetForm() {
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -408,7 +408,7 @@ function loadKeyData() {
initial_balance_usd: props.apiKey.initial_balance_usd,
expire_days: props.apiKey.expire_days,
never_expire: props.apiKey.never_expire,
rate_limit: props.apiKey.rate_limit || 100,
rate_limit: props.apiKey.rate_limit,
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
allowed_providers: props.apiKey.allowed_providers || [],
allowed_api_formats: props.apiKey.allowed_api_formats || [],

View File

@@ -2,174 +2,304 @@
<Dialog
:model-value="open"
:title="isEditMode ? '编辑模型' : '创建统一模型'"
:description="isEditMode ? '修改模型配置和价格信息' : '添加一个新的全局模型定义'"
:description="isEditMode ? '修改模型配置和价格信息' : ''"
:icon="isEditMode ? SquarePen : Layers"
size="xl"
size="3xl"
@update:model-value="handleDialogUpdate"
>
<form
class="space-y-5 max-h-[70vh] overflow-y-auto pr-1"
@submit.prevent="handleSubmit"
<div
class="flex gap-4"
:class="isEditMode ? '' : 'h-[500px]'"
>
<!-- 基本信息 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
基本信息
</h4>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label
for="model-name"
class="text-xs"
>模型名称 *</Label>
<Input
id="model-name"
v-model="form.name"
placeholder="claude-3-5-sonnet-20241022"
:disabled="isEditMode"
required
/>
<p
v-if="!isEditMode"
class="text-xs text-muted-foreground"
>
创建后不可修改
</p>
</div>
<div class="space-y-1.5">
<Label
for="model-display-name"
class="text-xs"
>显示名称 *</Label>
<Input
id="model-display-name"
v-model="form.display_name"
placeholder="Claude 3.5 Sonnet"
required
/>
</div>
</div>
<div class="space-y-1.5">
<Label
for="model-description"
class="text-xs"
>描述</Label>
<Input
id="model-description"
v-model="form.description"
placeholder="简短描述此模型的特点"
/>
</div>
</section>
<!-- 能力配置 -->
<section class="space-y-2">
<h4 class="font-medium text-sm">
默认能力
</h4>
<div class="flex flex-wrap gap-2">
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_streaming"
type="checkbox"
class="rounded"
>
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
<span>流式输出</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_vision"
type="checkbox"
class="rounded"
>
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
<span>视觉理解</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_function_calling"
type="checkbox"
class="rounded"
>
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
<span>工具调用</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_extended_thinking"
type="checkbox"
class="rounded"
>
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
<span>深度思考</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_image_generation"
type="checkbox"
class="rounded"
>
<Image class="w-3.5 h-3.5 text-muted-foreground" />
<span>图像生成</span>
</label>
</div>
</section>
<!-- Key 能力配置 -->
<section
v-if="availableCapabilities.length > 0"
class="space-y-2"
<!-- 左侧模型选择仅创建模式 -->
<div
v-if="!isEditMode"
class="w-[260px] shrink-0 flex flex-col h-full"
>
<h4 class="font-medium text-sm">
Key 能力支持
</h4>
<div class="flex flex-wrap gap-2">
<label
v-for="cap in availableCapabilities"
:key="cap.name"
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
>
<input
type="checkbox"
:checked="form.supported_capabilities?.includes(cap.name)"
class="rounded"
@change="toggleCapability(cap.name)"
>
<span>{{ cap.display_name }}</span>
</label>
</div>
</section>
<!-- 价格配置 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
价格配置
</h4>
<TieredPricingEditor
ref="tieredPricingEditorRef"
v-model="tieredPricing"
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
/>
<!-- 按次计费 -->
<div class="flex items-center gap-3 pt-2 border-t">
<Label class="text-xs whitespace-nowrap">按次计费 ($/)</Label>
<!-- 搜索框 -->
<div class="relative mb-3">
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
:model-value="form.default_price_per_request ?? ''"
type="number"
step="0.001"
min="0"
class="w-32"
placeholder="留空不启用"
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
v-model="searchQuery"
type="text"
placeholder="搜索模型、提供商..."
class="pl-8 h-8 text-sm"
/>
<span class="text-xs text-muted-foreground">每次请求固定费用,可与 Token 计费叠加</span>
</div>
</section>
</form>
<!-- 模型列表两级结构 -->
<div class="flex-1 overflow-y-auto border rounded-lg min-h-0 scrollbar-thin">
<div
v-if="loading"
class="flex items-center justify-center h-32"
>
<Loader2 class="w-5 h-5 animate-spin text-muted-foreground" />
</div>
<template v-else>
<!-- 提供商分组 -->
<div
v-for="group in groupedModels"
:key="group.providerId"
class="border-b last:border-b-0"
>
<!-- 提供商标题行 -->
<div
class="flex items-center gap-2 px-2.5 py-2 cursor-pointer hover:bg-muted text-sm"
@click="toggleProvider(group.providerId)"
>
<ChevronRight
class="w-3.5 h-3.5 text-muted-foreground transition-transform shrink-0"
:class="expandedProvider === group.providerId ? 'rotate-90' : ''"
/>
<img
:src="getProviderLogoUrl(group.providerId)"
:alt="group.providerName"
class="w-4 h-4 rounded shrink-0 dark:invert dark:brightness-90"
@error="handleLogoError"
>
<span class="truncate font-medium text-xs flex-1">{{ group.providerName }}</span>
<span class="text-[10px] text-muted-foreground shrink-0">{{ group.models.length }}</span>
</div>
<!-- 模型列表 -->
<div
v-if="expandedProvider === group.providerId"
class="bg-muted/30"
>
<div
v-for="model in group.models"
:key="model.modelId"
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'bg-primary text-primary-foreground'
: 'hover:bg-muted'"
@click="selectModel(model)"
>
<span class="truncate font-medium">{{ model.modelName }}</span>
<span
class="truncate text-[10px]"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'text-primary-foreground/70'
: 'text-muted-foreground'"
>{{ model.modelId }}</span>
</div>
</div>
</div>
<div
v-if="groupedModels.length === 0"
class="text-center py-8 text-sm text-muted-foreground"
>
{{ searchQuery ? '未找到模型' : '加载中...' }}
</div>
</template>
</div>
</div>
<!-- 右侧表单 -->
<div
class="flex-1 overflow-y-auto h-full scrollbar-thin"
:class="isEditMode ? 'max-h-[70vh]' : ''"
>
<form
class="space-y-5"
@submit.prevent="handleSubmit"
>
<!-- 基本信息 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
基本信息
</h4>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label
for="model-name"
class="text-xs"
>模型名称 *</Label>
<Input
id="model-name"
v-model="form.name"
placeholder="claude-3-5-sonnet-20241022"
:disabled="isEditMode"
required
/>
</div>
<div class="space-y-1.5">
<Label
for="model-display-name"
class="text-xs"
>显示名称 *</Label>
<Input
id="model-display-name"
v-model="form.display_name"
placeholder="Claude 3.5 Sonnet"
required
/>
</div>
</div>
<div class="space-y-1.5">
<Label
for="model-description"
class="text-xs"
>描述</Label>
<Input
id="model-description"
:model-value="form.config?.description || ''"
placeholder="简短描述此模型的特点"
@update:model-value="(v) => setConfigField('description', v || undefined)"
/>
</div>
<div class="grid grid-cols-3 gap-3">
<div class="space-y-1.5">
<Label
for="model-family"
class="text-xs"
>模型系列</Label>
<Input
id="model-family"
:model-value="form.config?.family || ''"
placeholder=" GPT-4Claude 3"
@update:model-value="(v) => setConfigField('family', v || undefined)"
/>
</div>
<div class="space-y-1.5">
<Label
for="model-context-limit"
class="text-xs"
>上下文限制</Label>
<Input
id="model-context-limit"
type="number"
:model-value="form.config?.context_limit ?? ''"
placeholder=" 128000"
@update:model-value="(v) => setConfigField('context_limit', v ? Number(v) : undefined)"
/>
</div>
<div class="space-y-1.5">
<Label
for="model-output-limit"
class="text-xs"
>输出限制</Label>
<Input
id="model-output-limit"
type="number"
:model-value="form.config?.output_limit ?? ''"
placeholder=" 8192"
@update:model-value="(v) => setConfigField('output_limit', v ? Number(v) : undefined)"
/>
</div>
</div>
</section>
<!-- 能力配置 -->
<section class="space-y-2">
<h4 class="font-medium text-sm">
默认能力
</h4>
<div class="flex flex-wrap gap-2">
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.streaming !== false"
class="rounded"
@change="setConfigField('streaming', ($event.target as HTMLInputElement).checked)"
>
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
<span>流式</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.vision === true"
class="rounded"
@change="setConfigField('vision', ($event.target as HTMLInputElement).checked)"
>
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
<span>视觉</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.function_calling === true"
class="rounded"
@change="setConfigField('function_calling', ($event.target as HTMLInputElement).checked)"
>
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
<span>工具</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.extended_thinking === true"
class="rounded"
@change="setConfigField('extended_thinking', ($event.target as HTMLInputElement).checked)"
>
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
<span>思考</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.image_generation === true"
class="rounded"
@change="setConfigField('image_generation', ($event.target as HTMLInputElement).checked)"
>
<Image class="w-3.5 h-3.5 text-muted-foreground" />
<span>生图</span>
</label>
</div>
</section>
<!-- Key 能力配置 -->
<section
v-if="availableCapabilities.length > 0"
class="space-y-2"
>
<h4 class="font-medium text-sm">
Key 能力支持
</h4>
<div class="flex flex-wrap gap-2">
<label
v-for="cap in availableCapabilities"
:key="cap.name"
class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm"
>
<input
type="checkbox"
:checked="form.supported_capabilities?.includes(cap.name)"
class="rounded"
@change="toggleCapability(cap.name)"
>
<span>{{ cap.display_name }}</span>
</label>
</div>
</section>
<!-- 价格配置 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
价格配置
</h4>
<TieredPricingEditor
ref="tieredPricingEditorRef"
v-model="tieredPricing"
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
/>
<div class="flex items-center gap-3 pt-2 border-t">
<Label class="text-xs whitespace-nowrap">按次计费</Label>
<Input
:model-value="form.default_price_per_request ?? ''"
type="number"
step="0.001"
min="0"
class="w-24"
placeholder="$/"
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
/>
<span class="text-xs text-muted-foreground">可与 Token 计费叠加</span>
</div>
</section>
</form>
</div>
</div>
<template #footer>
<Button
@@ -180,7 +310,7 @@
取消
</Button>
<Button
:disabled="submitting"
:disabled="submitting || !form.name || !form.display_name"
@click="handleSubmit"
>
<Loader2
@@ -189,19 +319,35 @@
/>
{{ isEditMode ? '保存' : '创建' }}
</Button>
<Button
v-if="selectedModel && !isEditMode"
type="button"
variant="ghost"
@click="clearSelection"
>
清空
</Button>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen } from 'lucide-vue-next'
import { ref, computed, onMounted, watch } from 'vue'
import {
Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen,
Search, ChevronRight
} from 'lucide-vue-next'
import { Dialog, Button, Input, Label } from '@/components/ui'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
import { parseNumberInput } from '@/utils/form'
import { log } from '@/utils/logger'
import TieredPricingEditor from './TieredPricingEditor.vue'
import {
getModelsDevList,
getProviderLogoUrl,
type ModelsDevModelItem,
} from '@/api/models-dev'
import {
createGlobalModel,
updateGlobalModel,
@@ -226,42 +372,147 @@ const { success, error: showError } = useToast()
const submitting = ref(false)
const tieredPricingEditorRef = ref<InstanceType<typeof TieredPricingEditor> | null>(null)
// 阶梯计费配置(统一使用,固定价格就是单阶梯)
// 模型列表相关
const loading = ref(false)
const searchQuery = ref('')
const allModelsCache = ref<ModelsDevModelItem[]>([]) // 全部模型(缓存)
const selectedModel = ref<ModelsDevModelItem | null>(null)
const expandedProvider = ref<string | null>(null)
// 当前显示的模型列表:有搜索词时用全部,否则只用官方
const allModels = computed(() => {
if (searchQuery.value) {
return allModelsCache.value
}
return allModelsCache.value.filter(m => m.official)
})
// 按提供商分组的模型
interface ProviderGroup {
providerId: string
providerName: string
models: ModelsDevModelItem[]
}
const groupedModels = computed(() => {
let models = allModels.value.filter(m => !m.deprecated)
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
models = models.filter(model => {
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 按提供商分组
const groups = new Map<string, ProviderGroup>()
for (const model of models) {
if (!groups.has(model.providerId)) {
groups.set(model.providerId, {
providerId: model.providerId,
providerName: model.providerName,
models: []
})
}
groups.get(model.providerId)!.models.push(model)
}
// 转换为数组并排序
const result = Array.from(groups.values())
// 如果有搜索词,把提供商名称/ID匹配的排在前面
if (searchQuery.value) {
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result.sort((a, b) => {
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
const aProviderMatch = keywords.some(k => aText.includes(k))
const bProviderMatch = keywords.some(k => bText.includes(k))
if (aProviderMatch && !bProviderMatch) return -1
if (!aProviderMatch && bProviderMatch) return 1
return a.providerName.localeCompare(b.providerName)
})
} else {
result.sort((a, b) => a.providerName.localeCompare(b.providerName))
}
return result
})
// 搜索时如果只有一个提供商,自动展开
watch(groupedModels, (groups) => {
if (searchQuery.value && groups.length === 1) {
expandedProvider.value = groups[0].providerId
}
})
// 切换提供商展开状态
function toggleProvider(providerId: string) {
expandedProvider.value = expandedProvider.value === providerId ? null : providerId
}
// 阶梯计费配置
const tieredPricing = ref<TieredPricingConfig | null>(null)
interface FormData {
name: string
display_name: string
description?: string
default_price_per_request?: number
default_supports_streaming?: boolean
default_supports_image_generation?: boolean
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_extended_thinking?: boolean
supported_capabilities?: string[]
config?: Record<string, any>
is_active?: boolean
}
const defaultForm = (): FormData => ({
name: '',
display_name: '',
description: '',
default_price_per_request: undefined,
default_supports_streaming: true,
default_supports_image_generation: false,
default_supports_vision: false,
default_supports_function_calling: false,
default_supports_extended_thinking: false,
supported_capabilities: [],
config: { streaming: true },
is_active: true,
})
const form = ref<FormData>(defaultForm())
const KEEP_FALSE_CONFIG_KEYS = new Set(['streaming'])
// 设置 config 字段
function setConfigField(key: string, value: any) {
if (!form.value.config) {
form.value.config = {}
}
if (value === undefined || value === '' || (value === false && !KEEP_FALSE_CONFIG_KEYS.has(key))) {
delete form.value.config[key]
} else {
form.value.config[key] = value
}
}
// Key 能力选项
const availableCapabilities = ref<CapabilityDefinition[]>([])
// 加载模型列表
async function loadModels() {
if (allModelsCache.value.length > 0) return
loading.value = true
try {
// 只加载一次全部模型,过滤在 computed 中完成
allModelsCache.value = await getModelsDevList(false)
} catch (err) {
log.error('Failed to load models:', err)
} finally {
loading.value = false
}
}
// 打开对话框时加载数据
watch(() => props.open, (isOpen) => {
if (isOpen && !props.model) {
loadModels()
}
})
// 加载可用能力列表
async function loadCapabilities() {
try {
@@ -284,38 +535,92 @@ function toggleCapability(capName: string) {
}
}
// 组件挂载时加载能力列表
onMounted(() => {
loadCapabilities()
})
// 选择模型并填充表单
function selectModel(model: ModelsDevModelItem) {
selectedModel.value = model
expandedProvider.value = model.providerId
form.value.name = model.modelId
form.value.display_name = model.modelName
// 构建 config
const config: Record<string, any> = {
streaming: true,
}
if (model.supportsVision) config.vision = true
if (model.supportsToolCall) config.function_calling = true
if (model.supportsReasoning) config.extended_thinking = true
if (model.supportsStructuredOutput) config.structured_output = true
if (model.supportsTemperature !== false) config.temperature = model.supportsTemperature
if (model.supportsAttachment) config.attachment = true
if (model.openWeights) config.open_weights = true
if (model.contextLimit) config.context_limit = model.contextLimit
if (model.outputLimit) config.output_limit = model.outputLimit
if (model.knowledgeCutoff) config.knowledge_cutoff = model.knowledgeCutoff
if (model.family) config.family = model.family
if (model.releaseDate) config.release_date = model.releaseDate
if (model.inputModalities?.length) config.input_modalities = model.inputModalities
if (model.outputModalities?.length) config.output_modalities = model.outputModalities
form.value.config = config
if (model.inputPrice !== undefined || model.outputPrice !== undefined) {
tieredPricing.value = {
tiers: [{
up_to: null,
input_price_per_1m: model.inputPrice || 0,
output_price_per_1m: model.outputPrice || 0,
}]
}
} else {
tieredPricing.value = null
}
}
// 清除选择(手动填写)
function clearSelection() {
selectedModel.value = null
form.value = defaultForm()
tieredPricing.value = null
}
// Logo 加载失败处理
function handleLogoError(event: Event) {
const img = event.target as HTMLImageElement
img.style.display = 'none'
}
// 重置表单
function resetForm() {
form.value = defaultForm()
tieredPricing.value = null
searchQuery.value = ''
selectedModel.value = null
expandedProvider.value = null
}
// 加载模型数据(编辑模式)
function loadModelData() {
if (!props.model) return
// 先重置创建模式的残留状态
selectedModel.value = null
searchQuery.value = ''
expandedProvider.value = null
form.value = {
name: props.model.name,
display_name: props.model.display_name,
description: props.model.description,
default_price_per_request: props.model.default_price_per_request,
default_supports_streaming: props.model.default_supports_streaming,
default_supports_image_generation: props.model.default_supports_image_generation,
default_supports_vision: props.model.default_supports_vision,
default_supports_function_calling: props.model.default_supports_function_calling,
default_supports_extended_thinking: props.model.default_supports_extended_thinking,
supported_capabilities: [...(props.model.supported_capabilities || [])],
config: props.model.config ? { ...props.model.config } : { streaming: true },
is_active: props.model.is_active,
}
// 加载阶梯计费配置(深拷贝)
if (props.model.default_tiered_pricing) {
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
}
// 确保 tieredPricing 也被正确设置或重置
tieredPricing.value = props.model.default_tiered_pricing
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
: null
}
// 使用 useFormDialog 统一处理对话框逻辑
@@ -339,24 +644,22 @@ async function handleSubmit() {
return
}
// 获取包含自动计算缓存价格的最终数据
const finalTiers = tieredPricingEditorRef.value?.getFinalTiers()
const finalTieredPricing = finalTiers ? { tiers: finalTiers } : tieredPricing.value
// 清理空的 config
const cleanConfig = form.value.config && Object.keys(form.value.config).length > 0
? form.value.config
: undefined
submitting.value = true
try {
if (isEditMode.value && props.model) {
const updateData: GlobalModelUpdate = {
display_name: form.value.display_name,
description: form.value.description,
// 使用 null 而不是 undefined 来显式清空字段
config: cleanConfig || null,
default_price_per_request: form.value.default_price_per_request ?? null,
default_tiered_pricing: finalTieredPricing,
default_supports_streaming: form.value.default_supports_streaming,
default_supports_image_generation: form.value.default_supports_image_generation,
default_supports_vision: form.value.default_supports_vision,
default_supports_function_calling: form.value.default_supports_function_calling,
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : null,
is_active: form.value.is_active,
}
@@ -366,14 +669,9 @@ async function handleSubmit() {
const createData: GlobalModelCreate = {
name: form.value.name!,
display_name: form.value.display_name!,
description: form.value.description,
default_price_per_request: form.value.default_price_per_request || undefined,
config: cleanConfig,
default_price_per_request: form.value.default_price_per_request ?? undefined,
default_tiered_pricing: finalTieredPricing,
default_supports_streaming: form.value.default_supports_streaming,
default_supports_image_generation: form.value.default_supports_image_generation,
default_supports_vision: form.value.default_supports_vision,
default_supports_function_calling: form.value.default_supports_function_calling,
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : undefined,
is_active: form.value.is_active,
}

View File

@@ -38,12 +38,12 @@
>
<Copy class="w-3 h-3" />
</button>
<template v-if="model.description">
<template v-if="model.config?.description">
<span class="shrink-0">·</span>
<span
class="text-xs truncate"
:title="model.description"
>{{ model.description }}</span>
:title="model.config?.description"
>{{ model.config?.description }}</span>
</template>
</div>
</div>
@@ -143,10 +143,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -160,10 +160,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -177,10 +177,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
:variant="model.config?.vision === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
{{ model.config?.vision === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -194,10 +194,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -211,10 +211,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
</Badge>
</div>
</div>
@@ -396,11 +396,11 @@
</div>
<div class="p-3 rounded-lg border bg-muted/20">
<div class="flex items-center justify-between">
<Label class="text-xs text-muted-foreground">别名数量</Label>
<Tag class="w-4 h-4 text-muted-foreground" />
<Label class="text-xs text-muted-foreground">调用次数</Label>
<BarChart3 class="w-4 h-4 text-muted-foreground" />
</div>
<p class="text-2xl font-bold mt-1">
{{ model.alias_count || 0 }}
{{ model.usage_count || 0 }}
</p>
</div>
</div>
@@ -455,105 +455,153 @@
<template v-else-if="providers.length > 0">
<!-- 桌面端表格 -->
<Table class="hidden sm:table">
<TableHeader>
<TableRow class="border-b border-border/60 hover:bg-transparent">
<TableHead class="h-10 font-semibold">
Provider
</TableHead>
<TableHead class="w-[120px] h-10 font-semibold">
能力
</TableHead>
<TableHead class="w-[180px] h-10 font-semibold">
价格 ($/M)
</TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<TableRow
<TableHeader>
<TableRow class="border-b border-border/60 hover:bg-transparent">
<TableHead class="h-10 font-semibold">
Provider
</TableHead>
<TableHead class="w-[120px] h-10 font-semibold">
能力
</TableHead>
<TableHead class="w-[180px] h-10 font-semibold">
价格 ($/M)
</TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<TableRow
v-for="provider in providers"
:key="provider.id"
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
>
<TableCell class="py-3">
<div class="flex items-center gap-2">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
</TableCell>
<TableCell class="py-3">
<div class="flex gap-0.5">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
title="流式输出"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
title="工具调用"
/>
</div>
</TableCell>
<TableCell class="py-3">
<div class="text-xs font-mono space-y-0.5">
<!-- Token 计费输入/输出 -->
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
<span class="text-muted-foreground">输入/输出:</span>
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
<!-- 阶梯标记 -->
<span
v-if="(provider.tier_count || 1) > 1"
class="ml-1 text-muted-foreground"
title="阶梯计费"
>[阶梯]</span>
</div>
<!-- 缓存价格 -->
<div
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>缓存:</span>
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 1h 缓存价格 -->
<div
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>1h 缓存:</span>
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 按次计费 -->
<div v-if="(provider.price_per_request || 0) > 0">
<span class="text-muted-foreground">按次:</span>
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/</span>
</div>
<!-- 无定价 -->
<span
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
class="text-muted-foreground"
>-</span>
</div>
</TableCell>
<TableCell class="py-3 text-center">
<div class="flex items-center justify-center gap-1">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑此关联"
@click="$emit('editProvider', provider)"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
:title="provider.is_active ? '停用此关联' : '启用此关联'"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="删除此关联"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</TableCell>
</TableRow>
</TableBody>
</Table>
<!-- 移动端卡片列表 -->
<div class="sm:hidden divide-y divide-border/40">
<div
v-for="provider in providers"
:key="provider.id"
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
class="p-4 space-y-3"
>
<TableCell class="py-3">
<div class="flex items-center gap-2">
<div class="flex items-start justify-between gap-3">
<div class="flex items-center gap-2 min-w-0">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
</TableCell>
<TableCell class="py-3">
<div class="flex gap-0.5">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
title="流式输出"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
title="工具调用"
/>
</div>
</TableCell>
<TableCell class="py-3">
<div class="text-xs font-mono space-y-0.5">
<!-- Token 计费输入/输出 -->
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
<span class="text-muted-foreground">输入/输出:</span>
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
<!-- 阶梯标记 -->
<span
v-if="(provider.tier_count || 1) > 1"
class="ml-1 text-muted-foreground"
title="阶梯计费"
>[阶梯]</span>
</div>
<!-- 缓存价格 -->
<div
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>缓存:</span>
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 1h 缓存价格 -->
<div
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>1h 缓存:</span>
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 按次计费 -->
<div v-if="(provider.price_per_request || 0) > 0">
<span class="text-muted-foreground">按次:</span>
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/</span>
</div>
<!-- 无定价 -->
<span
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
class="text-muted-foreground"
>-</span>
</div>
</TableCell>
<TableCell class="py-3 text-center">
<div class="flex items-center justify-center gap-1">
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑此关联"
@click="$emit('editProvider', provider)"
>
<Edit class="w-3.5 h-3.5" />
@@ -562,7 +610,6 @@
variant="ghost"
size="icon"
class="h-7 w-7"
:title="provider.is_active ? '停用此关联' : '启用此关联'"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
@@ -571,82 +618,35 @@
variant="ghost"
size="icon"
class="h-7 w-7"
title="删除此关联"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</TableCell>
</TableRow>
</TableBody>
</Table>
<!-- 移动端卡片列表 -->
<div class="sm:hidden divide-y divide-border/40">
<div
v-for="provider in providers"
:key="provider.id"
class="p-4 space-y-3"
>
<div class="flex items-start justify-between gap-3">
<div class="flex items-center gap-2 min-w-0">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('editProvider', provider)"
<div class="flex items-center gap-3 text-xs">
<div class="flex gap-1">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
/>
</div>
<div
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
class="text-muted-foreground font-mono"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
</div>
</div>
</div>
<div class="flex items-center gap-3 text-xs">
<div class="flex gap-1">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
/>
</div>
<div
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
class="text-muted-foreground font-mono"
>
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
</div>
</div>
</div>
</div>
</template>
@@ -695,7 +695,8 @@ import {
Loader2,
RefreshCw,
Copy,
Layers
Layers,
BarChart3
} from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import Card from '@/components/ui/card.vue'

View File

@@ -9,7 +9,7 @@
>
<form
class="space-y-6"
@submit.prevent="handleSubmit"
@submit.prevent="handleSubmit()"
>
<!-- API 配置 -->
<div class="space-y-4">
@@ -132,6 +132,79 @@
</div>
</div>
</div>
<!-- 代理配置 -->
<div class="space-y-4">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch v-model="proxyEnabled" />
<span class="text-sm text-muted-foreground">启用代理</span>
</div>
</div>
<div
v-if="proxyEnabled"
class="space-y-4 rounded-lg border p-4"
>
<div class="space-y-2">
<Label for="proxy_url">代理 URL *</Label>
<Input
id="proxy_url"
v-model="form.proxy_url"
placeholder="http://host:port 或 socks5://host:port"
required
:class="proxyUrlError ? 'border-red-500' : ''"
/>
<p
v-if="proxyUrlError"
class="text-xs text-red-500"
>
{{ proxyUrlError }}
</p>
<p
v-else
class="text-xs text-muted-foreground"
>
支持 HTTPHTTPSSOCKS5 代理
</p>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="proxy_user">用户名可选</Label>
<Input
:id="`proxy_user_${formId}`"
:name="`proxy_user_${formId}`"
v-model="form.proxy_username"
placeholder="代理认证用户名"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-2">
<Label :for="`proxy_pass_${formId}`">密码可选</Label>
<Input
:id="`proxy_pass_${formId}`"
:name="`proxy_pass_${formId}`"
v-model="form.proxy_password"
type="text"
:placeholder="passwordPlaceholder"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
/>
</div>
</div>
</div>
</div>
</form>
<template #footer>
@@ -145,12 +218,24 @@
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
@click="handleSubmit"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
</Button>
</template>
</Dialog>
<!-- 确认清空凭据对话框 -->
<AlertDialog
v-model="showClearCredentialsDialog"
title="清空代理凭据"
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
type="warning"
confirm-text="清空并保存"
cancel-text="返回编辑"
@confirm="confirmClearCredentials"
@cancel="showClearCredentialsDialog = false"
/>
</template>
<script setup lang="ts">
@@ -165,7 +250,9 @@ import {
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue'
import { Link, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
@@ -194,6 +281,11 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
// 生成随机 ID 防止浏览器自动填充
const formId = Math.random().toString(36).substring(2, 10)
// 内部状态
const internalOpen = computed(() => props.modelValue)
@@ -207,7 +299,11 @@ const form = ref({
max_retries: 3,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true
is_active: true,
// 代理配置
proxy_url: '',
proxy_username: '',
proxy_password: '',
})
// API 格式列表
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
})
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
const hasExistingPassword = computed(() => {
if (!props.endpoint?.proxy) return false
const proxy = props.endpoint.proxy as { password?: string }
return proxy?.password === MASKED_PASSWORD
})
// 密码输入框的 placeholder
const passwordPlaceholder = computed(() => {
if (hasExistingPassword.value) {
return '已保存密码,留空保持不变'
}
return '代理认证密码'
})
// 代理 URL 验证
const proxyUrlError = computed(() => {
// 只有启用代理且填写了 URL 时才验证
if (!proxyEnabled.value || !form.value.proxy_url) {
return ''
}
const url = form.value.proxy_url.trim()
// 检查禁止的特殊字符
if (/[\n\r]/.test(url)) {
return '代理 URL 包含非法字符'
}
// 验证协议(不支持 SOCKS4
if (!/^(http|https|socks5):\/\//i.test(url)) {
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
}
try {
const parsed = new URL(url)
if (!parsed.host) {
return '代理 URL 必须包含有效的 host'
}
// 禁止 URL 中内嵌认证信息
if (parsed.username || parsed.password) {
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
}
} catch {
return '代理 URL 格式无效'
}
return ''
})
// 组件挂载时加载API格式
onMounted(() => {
loadApiFormats()
@@ -252,14 +395,23 @@ function resetForm() {
max_retries: 3,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true
is_active: true,
proxy_url: '',
proxy_username: '',
proxy_password: '',
}
proxyEnabled.value = false
}
// 原始密码占位符(后端返回的脱敏标记)
const MASKED_PASSWORD = '***'
// 加载端点数据(编辑模式)
function loadEndpointData() {
if (!props.endpoint) return
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
form.value = {
api_format: props.endpoint.api_format,
base_url: props.endpoint.base_url,
@@ -268,8 +420,15 @@ function loadEndpointData() {
max_retries: props.endpoint.max_retries,
max_concurrent: props.endpoint.max_concurrent || undefined,
rate_limit: props.endpoint.rate_limit || undefined,
is_active: props.endpoint.is_active
is_active: props.endpoint.is_active,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
}
// 根据 enabled 字段或 url 存在判断是否启用代理
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
}
// 使用 useFormDialog 统一处理对话框逻辑
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
resetForm,
})
// 构建代理配置
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
// - 无 URL 时返回 null
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
if (!form.value.proxy_url) {
// 没填 URL无代理配置
return null
}
return {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: proxyEnabled.value, // 开关状态决定是否启用
}
}
// 提交表单
const handleSubmit = async () => {
const handleSubmit = async (skipCredentialCheck = false) => {
if (!props.provider && !props.endpoint) return
// 只在开关开启且填写了 URL 时验证
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
showError(proxyUrlError.value, '代理配置错误')
return
}
// 检查:开关开启但没有 URL却有用户名或密码
const hasOrphanedCredentials = proxyEnabled.value
&& !form.value.proxy_url
&& (form.value.proxy_username || form.value.proxy_password)
if (hasOrphanedCredentials && !skipCredentialCheck) {
// 弹出确认对话框
showClearCredentialsDialog.value = true
return
}
loading.value = true
try {
const proxyConfig = buildProxyConfig()
if (isEditMode.value && props.endpoint) {
// 更新端点
await updateEndpoint(props.endpoint.id, {
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点已更新', '保存成功')
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点创建成功', '成功')
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
loading.value = false
}
}
// 确认清空凭据并继续保存
const confirmClearCredentials = () => {
form.value.proxy_username = ''
form.value.proxy_password = ''
showClearCredentialsDialog.value = false
handleSubmit(true) // 跳过凭据检查,直接提交
}
</script>

View File

@@ -117,8 +117,12 @@
class="text-center py-6 text-muted-foreground border rounded-lg border-dashed"
>
<Tag class="w-8 h-8 mx-auto mb-2 opacity-50" />
<p class="text-sm">未配置映射</p>
<p class="text-xs mt-1">将只使用主模型名称</p>
<p class="text-sm">
未配置映射
</p>
<p class="text-xs mt-1">
将只使用主模型名称
</p>
</div>
</div>
</div>

View File

@@ -312,8 +312,41 @@
<template #footer>
<div class="flex items-center justify-between w-full">
<div class="text-xs text-muted-foreground">
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
<div class="flex items-center gap-4">
<div class="text-xs text-muted-foreground">
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
</div>
<div class="flex items-center gap-2 pl-4 border-l border-border">
<span class="text-xs text-muted-foreground">调度:</span>
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'cache_affinity'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="优先使用已缓存的Provider利用Prompt Cache"
@click="schedulingMode = 'cache_affinity'"
>
缓存亲和
</button>
</div>
</div>
</div>
<div class="flex gap-2">
<Button
@@ -410,6 +443,9 @@ const saving = ref(false)
// Key 优先级编辑状态
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
// 调度模式状态
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
// 可用的 API 格式
const availableFormats = computed(() => {
return Object.keys(keysByFormat.value).sort()
@@ -433,11 +469,18 @@ watch(internalOpen, async (open) => {
// 加载当前的优先级模式配置
async function loadCurrentPriorityMode() {
try {
const response = await adminApi.getSystemConfig('provider_priority_mode')
const currentMode = response.value || 'provider'
const [priorityResponse, schedulingResponse] = await Promise.all([
adminApi.getSystemConfig('provider_priority_mode'),
adminApi.getSystemConfig('scheduling_mode')
])
const currentMode = priorityResponse.value || 'provider'
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
} catch {
activeMainTab.value = 'provider'
schedulingMode.value = 'cache_affinity'
}
}
@@ -611,11 +654,19 @@ async function save() {
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
await adminApi.updateSystemConfig(
'provider_priority_mode',
newMode,
'Provider/Key 优先级策略provider(提供商优先模式) 或 global_key(全局Key优先模式)'
)
// 保存优先级模式和调度模式
await Promise.all([
adminApi.updateSystemConfig(
'provider_priority_mode',
newMode,
'Provider/Key 优先级策略provider(提供商优先模式) 或 global_key(全局Key优先模式)'
),
adminApi.updateSystemConfig(
'scheduling_mode',
schedulingMode.value,
'调度模式fixed_order(固定顺序模式) 或 cache_affinity(缓存亲和模式)'
)
])
const providerUpdates = sortedProviders.value.map((provider, index) =>
updateProvider(provider.id, { provider_priority: index + 1 })

View File

@@ -526,7 +526,14 @@
@edit-model="handleEditModel"
@delete-model="handleDeleteModel"
@batch-assign="handleBatchAssign"
@manage-alias="handleManageAlias"
/>
<!-- 模型名称映射 -->
<ModelAliasesTab
v-if="provider"
:key="`aliases-${provider.id}`"
:provider="provider"
@refresh="handleRelatedDataRefresh"
/>
</div>
</template>
@@ -629,16 +636,6 @@
@update:open="batchAssignDialogOpen = $event"
@changed="handleBatchAssignChanged"
/>
<!-- 模型别名管理对话框 -->
<ModelAliasDialog
v-if="open && provider"
:open="aliasDialogOpen"
:provider-id="provider.id"
:model="aliasEditingModel"
@update:open="aliasDialogOpen = $event"
@saved="handleAliasSaved"
/>
</template>
<script setup lang="ts">
@@ -667,8 +664,8 @@ import {
KeyFormDialog,
KeyAllowedModelsDialog,
ModelsTab,
BatchAssignModelsDialog,
ModelAliasDialog
ModelAliasesTab,
BatchAssignModelsDialog
} from '@/features/providers/components'
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false)
// 别名管理相关状态
const aliasDialogOpen = ref(false)
const aliasEditingModel = ref<Model | null>(null)
// 拖动排序相关状态
const dragState = ref({
isDragging: false,
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value ||
deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value ||
aliasDialogOpen.value
batchAssignDialogOpen.value
)
// 监听 providerId 变化
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
keyAllowedModelsDialogOpen.value = false
deleteKeyConfirmOpen.value = false
batchAssignDialogOpen.value = false
aliasDialogOpen.value = false
// 重置临时数据
endpointToEdit.value = null
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
emit('refresh')
}
// 处理管理映射 - 打开别名对话框
function handleManageAlias(model: Model) {
aliasEditingModel.value = model
aliasDialogOpen.value = true
}
// 处理别名保存完成
async function handleAliasSaved() {
aliasEditingModel.value = null
await loadProvider()
emit('refresh')
}
// 处理模型保存完成
async function handleModelSaved() {
editingModel.value = null

View File

@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'

File diff suppressed because it is too large Load Diff

View File

@@ -165,15 +165,6 @@
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="管理映射"
@click="openAliasDialog(model)"
>
<Tag class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
@@ -218,7 +209,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast'
@@ -233,7 +224,6 @@ const emit = defineEmits<{
'editModel': [model: Model]
'deleteModel': [model: Model]
'batchAssign': []
'manageAlias': [model: Model]
}>()
const { error: showError, success: showSuccess } = useToast()
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
emit('batchAssign')
}
// 打开别名管理对话框
function openAliasDialog(model: Model) {
emit('manageAlias', model)
}
// 切换模型启用状态
async function toggleModelActive(model: Model) {
if (togglingModelId.value) return

View File

@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
return groups
})
// 计算链路总耗时(从第一个节点开始到最后一个节点结束
// 计算链路总耗时(使用成功候选的 latency_ms 字段
// 优先使用 latency_ms因为它与 Usage.response_time_ms 使用相同的时间基准
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
const totalTraceLatency = computed(() => {
if (!timeline.value || timeline.value.length === 0) return 0
// 查找成功的候选,使用其 latency_ms
const successCandidate = timeline.value.find(c => c.status === 'success')
if (successCandidate?.latency_ms != null) {
return successCandidate.latency_ms
}
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
if (failedWithLatency?.latency_ms != null) {
return failedWithLatency.latency_ms
}
// 回退:使用 finished_at - started_at 计算
let earliestStart: number | null = null
let latestEnd: number | null = null

View File

@@ -25,7 +25,7 @@
</h3>
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
<span>{{ detail?.model || '-' }}</span>
<template v-if="detail?.target_model">
<template v-if="detail?.target_model && detail.target_model !== detail.model">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"

View File

@@ -177,8 +177,9 @@
费用
</TableHead>
<TableHead class="h-12 font-semibold w-[70px] text-right">
<div class="inline-block max-w-[2rem] leading-tight">
响应时间
<div class="flex flex-col items-end text-xs gap-0.5">
<span>首字</span>
<span class="text-muted-foreground font-normal">总耗时</span>
</div>
</TableHead>
</TableRow>
@@ -356,15 +357,28 @@
</div>
</TableCell>
<TableCell class="text-right py-4 w-[70px]">
<span
<div
v-if="record.status === 'pending' || record.status === 'streaming'"
class="text-primary tabular-nums"
class="flex flex-col items-end text-xs gap-0.5"
>
{{ getElapsedTime(record) }}
</span>
<span v-else-if="record.response_time_ms">
{{ (record.response_time_ms / 1000).toFixed(2) }}s
</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
</div>
<span
v-else
class="text-muted-foreground"

View File

@@ -78,6 +78,7 @@ export interface UsageRecord {
cost: number
actual_cost?: number
response_time_ms?: number
first_byte_time_ms?: number // 首字时间 (TTFB)
is_stream: boolean
status_code?: number
error_message?: string

View File

@@ -611,41 +611,42 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
id: 'gm-001',
name: 'claude-haiku-4-5-20251001',
display_name: 'claude-haiku-4-5',
description: 'Anthropic 最快速的 Claude 4 系列模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.00, output_price_per_1m: 5.00, cache_creation_price_per_1m: 1.25, cache_read_price_per_1m: 0.1 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 最快速的 Claude 4 系列模型'
},
provider_count: 3,
alias_count: 2,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-002',
name: 'claude-opus-4-5-20251101',
display_name: 'claude-opus-4-5',
description: 'Anthropic 最强大的模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 5.00, output_price_per_1m: 25.00, cache_creation_price_per_1m: 6.25, cache_read_price_per_1m: 0.5 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 最强大的模型'
},
provider_count: 2,
alias_count: 1,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-003',
name: 'claude-sonnet-4-5-20250929',
display_name: 'claude-sonnet-4-5',
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文',
is_active: true,
default_tiered_pricing: {
tiers: [
@@ -677,116 +678,124 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
}
]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文'
},
supported_capabilities: ['cache_1h', 'cli_1m'],
provider_count: 3,
alias_count: 2,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-004',
name: 'gemini-3-pro-image-preview',
display_name: 'gemini-3-pro-image-preview',
description: 'Google Gemini 3 Pro 图像生成预览版',
is_active: true,
default_price_per_request: 0.300,
default_tiered_pricing: {
tiers: []
},
default_supports_vision: true,
default_supports_function_calling: false,
default_supports_streaming: true,
default_supports_image_generation: true,
config: {
streaming: true,
vision: true,
function_calling: false,
image_generation: true,
description: 'Google Gemini 3 Pro 图像生成预览版'
},
provider_count: 1,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-005',
name: 'gemini-3-pro-preview',
display_name: 'gemini-3-pro-preview',
description: 'Google Gemini 3 Pro 预览版',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 2.00, output_price_per_1m: 12.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Google Gemini 3 Pro 预览版'
},
provider_count: 1,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-006',
name: 'gpt-5.1',
display_name: 'gpt-5.1',
description: 'OpenAI GPT-5.1 模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 模型'
},
provider_count: 2,
alias_count: 1,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-007',
name: 'gpt-5.1-codex',
display_name: 'gpt-5.1-codex',
description: 'OpenAI GPT-5.1 Codex 代码专用模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex 代码专用模型'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-008',
name: 'gpt-5.1-codex-max',
display_name: 'gpt-5.1-codex-max',
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-009',
name: 'gpt-5.1-codex-mini',
display_name: 'gpt-5.1-codex-mini',
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
}
]

View File

@@ -1000,17 +1000,11 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
id: m.id,
name: m.name,
display_name: m.display_name,
description: m.description,
icon_url: null,
is_active: m.is_active,
default_tiered_pricing: m.default_tiered_pricing,
default_price_per_request: null,
default_supports_vision: m.default_supports_vision,
default_supports_function_calling: m.default_supports_function_calling,
default_supports_streaming: m.default_supports_streaming,
default_supports_extended_thinking: m.default_supports_extended_thinking || false,
default_supports_image_generation: false,
supported_capabilities: null
default_price_per_request: m.default_price_per_request,
supported_capabilities: m.supported_capabilities,
config: m.config
})),
total: MOCK_GLOBAL_MODELS.length
})

View File

@@ -1169,4 +1169,26 @@ body[theme-mode='dark'] .literary-annotation {
.scrollbar-hide::-webkit-scrollbar {
display: none;
}
}
.scrollbar-thin {
scrollbar-width: thin;
scrollbar-color: hsl(var(--border)) transparent;
}
.scrollbar-thin::-webkit-scrollbar {
width: 6px;
}
.scrollbar-thin::-webkit-scrollbar-track {
background: transparent;
}
.scrollbar-thin::-webkit-scrollbar-thumb {
background-color: hsl(var(--border));
border-radius: 3px;
}
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
background-color: hsl(var(--muted-foreground) / 0.5);
}
}

View File

@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
const filteredApiKeys = computed(() => {
let result = apiKeys.value
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(key =>
(key.name && key.name.toLowerCase().includes(query)) ||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
(key.username && key.username.toLowerCase().includes(query)) ||
(key.user_email && key.user_email.toLowerCase().includes(query))
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(key => {
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 状态筛选

View File

@@ -142,32 +142,37 @@ async function resetAffinitySearch() {
await fetchAffinityList()
}
async function clearUserCache(identifier: string, displayName?: string) {
const target = identifier?.trim()
if (!target) {
showError('无法识别标识符')
async function clearSingleAffinity(item: UserAffinity) {
const affinityKey = item.affinity_key?.trim()
const endpointId = item.endpoint_id?.trim()
const modelId = item.global_model_id?.trim()
const apiFormat = item.api_format?.trim()
if (!affinityKey || !endpointId || !modelId || !apiFormat) {
showError('缓存记录信息不完整,无法删除')
return
}
const label = displayName || target
const label = item.user_api_key_name || affinityKey
const modelLabel = item.model_display_name || item.model_name || modelId
const confirmed = await showConfirm({
title: '确认清除',
message: `确定要清除 ${label} 的缓存吗?`,
message: `确定要清除 ${label} 在模型 ${modelLabel} 上的缓存亲和性吗?`,
confirmText: '确认清除',
variant: 'destructive'
})
if (!confirmed) return
clearingRowAffinityKey.value = target
clearingRowAffinityKey.value = affinityKey
try {
await cacheApi.clearUserCache(target)
await cacheApi.clearSingleAffinity(affinityKey, endpointId, modelId, apiFormat)
showSuccess('清除成功')
await fetchCacheStats()
await fetchAffinityList(tableKeyword.value.trim() || undefined)
} catch (error) {
showError('清除失败')
log.error('清除用户缓存失败', error)
log.error('清除单条缓存失败', error)
} finally {
clearingRowAffinityKey.value = null
}
@@ -618,7 +623,7 @@ onBeforeUnmount(() => {
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
:disabled="clearingRowAffinityKey === item.affinity_key"
title="清除缓存"
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
@click="clearSingleAffinity(item)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
@@ -668,7 +673,7 @@ onBeforeUnmount(() => {
variant="ghost"
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive shrink-0"
:disabled="clearingRowAffinityKey === item.affinity_key"
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
@click="clearSingleAffinity(item)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
@@ -935,7 +940,10 @@ onBeforeUnmount(() => {
:key="`${index}-${aliasIndex}`"
>
<TableCell>
<Badge variant="outline" class="text-xs">
<Badge
variant="outline"
class="text-xs"
>
{{ mapping.provider_name }}
</Badge>
</TableCell>
@@ -981,7 +989,10 @@ onBeforeUnmount(() => {
class="p-4 space-y-2"
>
<div class="flex items-center justify-between">
<Badge variant="outline" class="text-xs">
<Badge
variant="outline"
class="text-xs"
>
{{ mapping.provider_name }}
</Badge>
<div class="flex items-center gap-2">

View File

@@ -111,9 +111,6 @@
<TableHead class="w-[80px] text-center">
提供商
</TableHead>
<TableHead class="w-[70px] text-center">
别名/映射
</TableHead>
<TableHead class="w-[80px] text-center">
调用次数
</TableHead>
@@ -128,7 +125,7 @@
<TableBody>
<TableRow v-if="loading">
<TableCell
colspan="8"
colspan="7"
class="text-center py-8"
>
<Loader2 class="w-6 h-6 animate-spin mx-auto" />
@@ -136,7 +133,7 @@
</TableRow>
<TableRow v-else-if="filteredGlobalModels.length === 0">
<TableCell
colspan="8"
colspan="7"
class="text-center py-8 text-muted-foreground"
>
没有找到匹配的模型
@@ -171,27 +168,27 @@
<div class="space-y-1 w-fit">
<div class="flex flex-wrap gap-1">
<Zap
v-if="model.default_supports_streaming"
v-if="model.config?.streaming !== false"
class="w-4 h-4 text-muted-foreground"
title="流式输出"
/>
<Image
v-if="model.default_supports_image_generation"
v-if="model.config?.image_generation === true"
class="w-4 h-4 text-muted-foreground"
title="图像生成"
/>
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
title="工具调用"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
title="深度思考"
/>
@@ -244,11 +241,6 @@
{{ model.provider_count || 0 }}
</Badge>
</TableCell>
<TableCell class="text-center">
<Badge variant="secondary">
{{ model.alias_count || 0 }}
</Badge>
</TableCell>
<TableCell class="text-center">
<span class="text-sm font-mono">{{ formatUsageCount(model.usage_count || 0) }}</span>
</TableCell>
@@ -369,23 +361,23 @@
<!-- 第二行能力图标 -->
<div class="flex flex-wrap gap-1.5">
<Zap
v-if="model.default_supports_streaming"
v-if="model.config?.streaming !== false"
class="w-4 h-4 text-muted-foreground"
/>
<Image
v-if="model.default_supports_image_generation"
v-if="model.config?.image_generation === true"
class="w-4 h-4 text-muted-foreground"
/>
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
/>
</div>
@@ -393,7 +385,6 @@
<!-- 第三行统计信息 -->
<div class="flex flex-wrap items-center gap-3 text-xs text-muted-foreground">
<span>提供商 {{ model.provider_count || 0 }}</span>
<span>别名 {{ model.alias_count || 0 }}</span>
<span>调用 {{ formatUsageCount(model.usage_count || 0) }}</span>
<span
v-if="getFirstTierPrice(model, 'input') || getFirstTierPrice(model, 'output')"
@@ -1011,30 +1002,30 @@ async function batchRemoveSelectedProviders() {
const filteredGlobalModels = computed(() => {
let result = globalModels.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选
if (capabilityFilters.value.streaming) {
result = result.filter(m => m.default_supports_streaming)
result = result.filter(m => m.config?.streaming !== false)
}
if (capabilityFilters.value.imageGeneration) {
result = result.filter(m => m.default_supports_image_generation)
result = result.filter(m => m.config?.image_generation === true)
}
if (capabilityFilters.value.vision) {
result = result.filter(m => m.default_supports_vision)
result = result.filter(m => m.config?.vision === true)
}
if (capabilityFilters.value.toolUse) {
result = result.filter(m => m.default_supports_function_calling)
result = result.filter(m => m.config?.function_calling === true)
}
if (capabilityFilters.value.extendedThinking) {
result = result.filter(m => m.default_supports_extended_thinking)
result = result.filter(m => m.config?.extended_thinking === true)
}
return result

View File

@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
const filteredProviders = computed(() => {
let result = [...providers.value]
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value.trim()) {
const query = searchQuery.value.toLowerCase()
result = result.filter(p =>
p.display_name.toLowerCase().includes(query) ||
p.name.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(p => {
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 排序

View File

@@ -7,6 +7,7 @@
<template #actions>
<Button
:disabled="loading"
class="shadow-none hover:shadow-none"
@click="saveSystemConfig"
>
{{ loading ? '保存中...' : '保存所有配置' }}
@@ -15,6 +16,94 @@
</PageHeader>
<div class="mt-6 space-y-6">
<!-- 配置导出/导入 -->
<CardSection
title="配置管理"
description="导出或导入提供商和模型配置,便于备份或迁移"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出当前所有提供商端点API Key 和模型配置到 JSON 文件
</p>
<Button
variant="outline"
:disabled="exportLoading"
@click="handleExportConfig"
>
<Download class="w-4 h-4 mr-2" />
{{ exportLoading ? '导出中...' : '导出配置' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入配置支持跳过覆盖或报错三种冲突处理模式
</p>
<div class="flex items-center gap-2">
<input
ref="configFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleConfigFileSelect"
>
<Button
variant="outline"
:disabled="importLoading"
@click="triggerConfigFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importLoading ? '导入中...' : '导入配置' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 用户数据导出/导入 -->
<CardSection
title="用户数据管理"
description="导出或导入用户及其 API Keys 数据(不含管理员)"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出所有普通用户及其 API Keys JSON 文件
</p>
<Button
variant="outline"
:disabled="exportUsersLoading"
@click="handleExportUsers"
>
<Download class="w-4 h-4 mr-2" />
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入用户数据需相同 ENCRYPTION_KEY
</p>
<div class="flex items-center gap-2">
<input
ref="usersFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleUsersFileSelect"
>
<Button
variant="outline"
:disabled="importUsersLoading"
@click="triggerUsersFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 基础配置 -->
<CardSection
title="基础配置"
@@ -96,32 +185,13 @@
</div>
</CardSection>
<!-- API Key 管理配置 -->
<!-- 独立余额 Key 过期管理 -->
<CardSection
title="API Key 管理"
description="API Key 相关配置"
title="独立余额 Key 过期管理"
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div>
<Label
for="api-key-expire"
class="block text-sm font-medium"
>
API密钥过期天数
</Label>
<Input
id="api-key-expire"
v-model.number="systemConfig.api_key_expire_days"
type="number"
placeholder="0"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
0 表示永不过期
</p>
</div>
<div class="flex items-center h-full pt-6">
<div class="flex items-center h-full">
<div class="flex items-center space-x-2">
<Checkbox
id="auto-delete-expired-keys"
@@ -135,7 +205,7 @@
自动删除过期 Key
</Label>
<p class="text-xs text-muted-foreground">
关闭时仅禁用过期 Key
关闭时仅禁用过期 Key不会物理删除
</p>
</div>
</div>
@@ -359,6 +429,25 @@
避免单次操作过大影响性能
</p>
</div>
<div>
<Label
for="audit-log-retention-days"
class="block text-sm font-medium"
>
审计日志保留天数
</Label>
<Input
id="audit-log-retention-days"
v-model.number="systemConfig.audit_log_retention_days"
type="number"
placeholder="30"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
超过后删除审计日志记录
</p>
</div>
</div>
<!-- 清理策略说明 -->
@@ -371,15 +460,318 @@
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储节省空间</p>
<p>3. <strong>统计阶段</strong>: 仅保留 tokens成本等统计信息</p>
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
<p>5. <strong>审计日志</strong>: 独立清理记录用户登录操作等安全事件</p>
</div>
</div>
</CardSection>
</div>
<!-- 导入配置对话框 -->
<Dialog
v-model:open="importDialogOpen"
title="导入配置"
description="选择冲突处理模式并确认导入"
>
<div class="space-y-4">
<div
v-if="importPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
配置预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>全局模型: {{ importPreview.global_models?.length || 0 }} </li>
<li>提供商: {{ importPreview.providers?.length || 0 }} </li>
<li>
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }}
</li>
<li>
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select
v-model="mergeMode"
v-model:open="mergeModeSelectOpen"
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有配置
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入配置替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="mergeMode === 'skip'">
已存在的配置将被保留仅导入新配置
</template>
<template v-else-if="mergeMode === 'overwrite'">
已存在的配置将被导入的配置覆盖
</template>
<template v-else>
如果发现任何冲突导入将中止并回滚
</template>
</p>
</div>
<p class="text-xs text-muted-foreground">
注意相同的 API Keys 会自动跳过不会创建重复记录
</p>
</div>
<template #footer>
<Button
variant="outline"
@click="importDialogOpen = false; mergeModeSelectOpen = false"
>
取消
</Button>
<Button
:disabled="importLoading"
@click="confirmImport"
>
{{ importLoading ? '导入中...' : '确认导入' }}
</Button>
</template>
</Dialog>
<!-- 导入结果对话框 -->
<Dialog
v-model:open="importResultDialogOpen"
title="导入完成"
>
<div
v-if="importResult"
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
全局模型
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.global_models.created }},
更新: {{ importResult.stats.global_models.updated }},
跳过: {{ importResult.stats.global_models.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
提供商
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.providers.created }},
更新: {{ importResult.stats.providers.updated }},
跳过: {{ importResult.stats.providers.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
端点
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.endpoints.created }},
更新: {{ importResult.stats.endpoints.updated }},
跳过: {{ importResult.stats.endpoints.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.keys.created }},
跳过: {{ importResult.stats.keys.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg col-span-2">
<p class="font-medium">
模型配置
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.models.created }},
更新: {{ importResult.stats.models.updated }},
跳过: {{ importResult.stats.models.skipped }}
</p>
</div>
</div>
<div
v-if="importResult.stats.errors.length > 0"
class="p-3 bg-destructive/10 rounded-lg"
>
<p class="font-medium text-destructive mb-2">
警告信息
</p>
<ul class="text-sm text-destructive space-y-1">
<li
v-for="(err, index) in importResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<template #footer>
<Button @click="importResultDialogOpen = false">
确定
</Button>
</template>
</Dialog>
<!-- 用户数据导入对话框 -->
<Dialog
v-model:open="importUsersDialogOpen"
title="导入用户数据"
description="选择冲突处理模式并确认导入"
>
<div class="space-y-4">
<div
v-if="importUsersPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
数据预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>用户: {{ importUsersPreview.users?.length || 0 }} </li>
<li>
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select
v-model="usersMergeMode"
v-model:open="usersMergeModeSelectOpen"
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有用户
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入数据替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="usersMergeMode === 'skip'">
已存在的用户将被保留仅导入新用户
</template>
<template v-else-if="usersMergeMode === 'overwrite'">
已存在的用户将被导入的数据覆盖
</template>
<template v-else>
如果发现任何冲突导入将中止并回滚
</template>
</p>
</div>
<p class="text-xs text-muted-foreground">
注意用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作
</p>
</div>
<template #footer>
<Button
variant="outline"
@click="importUsersDialogOpen = false; usersMergeModeSelectOpen = false"
>
取消
</Button>
<Button
:disabled="importUsersLoading"
@click="confirmImportUsers"
>
{{ importUsersLoading ? '导入中...' : '确认导入' }}
</Button>
</template>
</Dialog>
<!-- 用户数据导入结果对话框 -->
<Dialog
v-model:open="importUsersResultDialogOpen"
title="用户数据导入完成"
>
<div
v-if="importUsersResult"
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
用户
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.users.created }},
更新: {{ importUsersResult.stats.users.updated }},
跳过: {{ importUsersResult.stats.users.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.api_keys.created }},
跳过: {{ importUsersResult.stats.api_keys.skipped }}
</p>
</div>
</div>
<div
v-if="importUsersResult.stats.errors.length > 0"
class="p-3 bg-destructive/10 rounded-lg"
>
<p class="font-medium text-destructive mb-2">
警告信息
</p>
<ul class="text-sm text-destructive space-y-1">
<li
v-for="(err, index) in importUsersResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<template #footer>
<Button @click="importUsersResultDialogOpen = false">
确定
</Button>
</template>
</Dialog>
</PageContainer>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Download, Upload } from 'lucide-vue-next'
import Button from '@/components/ui/button.vue'
import Input from '@/components/ui/input.vue'
import Label from '@/components/ui/label.vue'
@@ -389,9 +781,12 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
import SelectValue from '@/components/ui/select-value.vue'
import SelectContent from '@/components/ui/select-content.vue'
import SelectItem from '@/components/ui/select-item.vue'
import {
Dialog,
} from '@/components/ui'
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
import { useToast } from '@/composables/useToast'
import { adminApi } from '@/api/admin'
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
import { log } from '@/utils/logger'
const { success, error } = useToast()
@@ -403,8 +798,7 @@ interface SystemConfig {
// 用户注册
enable_registration: boolean
require_email_verification: boolean
// API Key 管理
api_key_expire_days: number
// 独立余额 Key 过期管理
auto_delete_expired_keys: boolean
// 日志记录
request_log_level: string
@@ -418,11 +812,34 @@ interface SystemConfig {
header_retention_days: number
log_retention_days: number
cleanup_batch_size: number
audit_log_retention_days: number
}
const loading = ref(false)
const logLevelSelectOpen = ref(false)
// 导出/导入相关
const exportLoading = ref(false)
const importLoading = ref(false)
const importDialogOpen = ref(false)
const importResultDialogOpen = ref(false)
const configFileInput = ref<HTMLInputElement | null>(null)
const importPreview = ref<ConfigExportData | null>(null)
const importResult = ref<ConfigImportResponse | null>(null)
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const mergeModeSelectOpen = ref(false)
// 用户数据导出/导入相关
const exportUsersLoading = ref(false)
const importUsersLoading = ref(false)
const importUsersDialogOpen = ref(false)
const importUsersResultDialogOpen = ref(false)
const usersFileInput = ref<HTMLInputElement | null>(null)
const importUsersPreview = ref<UsersExportData | null>(null)
const importUsersResult = ref<UsersImportResponse | null>(null)
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const usersMergeModeSelectOpen = ref(false)
const systemConfig = ref<SystemConfig>({
// 基础配置
default_user_quota_usd: 10.0,
@@ -430,8 +847,7 @@ const systemConfig = ref<SystemConfig>({
// 用户注册
enable_registration: false,
require_email_verification: false,
// API Key 管理
api_key_expire_days: 0,
// 独立余额 Key 过期管理
auto_delete_expired_keys: false,
// 日志记录
request_log_level: 'basic',
@@ -445,6 +861,7 @@ const systemConfig = ref<SystemConfig>({
header_retention_days: 90,
log_retention_days: 365,
cleanup_batch_size: 1000,
audit_log_retention_days: 30,
})
// 计算属性KB 和 字节 之间的转换
@@ -486,8 +903,7 @@ async function loadSystemConfig() {
// 用户注册
'enable_registration',
'require_email_verification',
// API Key 管理
'api_key_expire_days',
// 独立余额 Key 过期管理
'auto_delete_expired_keys',
// 日志记录
'request_log_level',
@@ -501,6 +917,7 @@ async function loadSystemConfig() {
'header_retention_days',
'log_retention_days',
'cleanup_batch_size',
'audit_log_retention_days',
]
for (const key of configs) {
@@ -545,12 +962,7 @@ async function saveSystemConfig() {
value: systemConfig.value.require_email_verification,
description: '是否需要邮箱验证'
},
// API Key 管理
{
key: 'api_key_expire_days',
value: systemConfig.value.api_key_expire_days,
description: 'API密钥过期天数'
},
// 独立余额 Key 过期管理
{
key: 'auto_delete_expired_keys',
value: systemConfig.value.auto_delete_expired_keys,
@@ -608,6 +1020,11 @@ async function saveSystemConfig() {
value: systemConfig.value.cleanup_batch_size,
description: '每批次清理的记录数'
},
{
key: 'audit_log_retention_days',
value: systemConfig.value.audit_log_retention_days,
description: '审计日志保留天数'
},
]
const promises = configItems.map(item =>
@@ -623,4 +1040,185 @@ async function saveSystemConfig() {
loading.value = false
}
}
// 导出配置
async function handleExportConfig() {
exportLoading.value = true
try {
const data = await adminApi.exportConfig()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-config-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('配置已导出')
} catch (err) {
error('导出配置失败')
log.error('导出配置失败:', err)
} finally {
exportLoading.value = false
}
}
// 触发文件选择
function triggerConfigFileSelect() {
configFileInput.value?.click()
}
// 文件大小限制 (10MB)
const MAX_FILE_SIZE = 10 * 1024 * 1024
// 处理文件选择
function handleConfigFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as ConfigExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importPreview.value = data
mergeMode.value = 'skip'
importDialogOpen.value = true
} catch (err) {
error('解析配置文件失败,请确保是有效的 JSON 文件')
log.error('解析配置文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
// 确认导入
async function confirmImport() {
if (!importPreview.value) return
importLoading.value = true
try {
const result = await adminApi.importConfig({
...importPreview.value,
merge_mode: mergeMode.value
})
importResult.value = result
importDialogOpen.value = false
mergeModeSelectOpen.value = false
importResultDialogOpen.value = true
success('配置导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入配置失败')
log.error('导入配置失败:', err)
} finally {
importLoading.value = false
}
}
// 导出用户数据
async function handleExportUsers() {
exportUsersLoading.value = true
try {
const data = await adminApi.exportUsers()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-users-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('用户数据已导出')
} catch (err) {
error('导出用户数据失败')
log.error('导出用户数据失败:', err)
} finally {
exportUsersLoading.value = false
}
}
// 触发用户数据文件选择
function triggerUsersFileSelect() {
usersFileInput.value?.click()
}
// 处理用户数据文件选择
function handleUsersFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as UsersExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importUsersPreview.value = data
usersMergeMode.value = 'skip'
importUsersDialogOpen.value = true
} catch (err) {
error('解析用户数据文件失败,请确保是有效的 JSON 文件')
log.error('解析用户数据文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
// 确认导入用户数据
async function confirmImportUsers() {
if (!importUsersPreview.value) return
importUsersLoading.value = true
try {
const result = await adminApi.importUsers({
...importUsersPreview.value,
merge_mode: usersMergeMode.value
})
importUsersResult.value = result
importUsersDialogOpen.value = false
usersMergeModeSelectOpen.value = false
importUsersResultDialogOpen.value = true
success('用户数据导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入用户数据失败')
log.error('导入用户数据失败:', err)
} finally {
importUsersLoading.value = false
}
}
</script>

View File

@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
})
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
filtered = filtered.filter(
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
filtered = filtered.filter(u => {
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
if (filterRole.value !== 'all') {

View File

@@ -103,7 +103,7 @@
</div>
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
平均响应
@@ -114,7 +114,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-kraft/30">
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
错误率
@@ -128,7 +128,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
转移次数
@@ -142,7 +142,7 @@
v-if="costStats"
class="relative p-3 sm:p-4 border-manilla/40"
>
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
实际成本
@@ -180,7 +180,7 @@
</div>
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存命中率
@@ -191,7 +191,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-kraft/30">
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存读取
@@ -202,7 +202,7 @@
</div>
</Card>
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
缓存创建
@@ -216,7 +216,7 @@
v-if="tokenBreakdown"
class="relative p-3 sm:p-4 border-manilla/40"
>
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
总Token
@@ -254,16 +254,16 @@
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
<div
v-if="loadingAnnouncements"
class="py-8 text-center"
class="flex-1 flex items-center justify-center"
>
<Loader2 class="h-5 w-5 animate-spin mx-auto text-muted-foreground" />
<Loader2 class="h-5 w-5 animate-spin text-muted-foreground" />
</div>
<div
v-else-if="announcements.length === 0"
class="py-8 text-center"
class="flex-1 flex flex-col items-center justify-center"
>
<Bell class="h-8 w-8 mx-auto text-muted-foreground/40" />
<Bell class="h-8 w-8 text-muted-foreground/40" />
<p class="mt-2 text-xs text-muted-foreground">
暂无公告
</p>
@@ -793,9 +793,8 @@ const statCardGlows = [
'bg-kraft/30'
]
const getStatIconColor = (index: number): string => {
const colors = ['text-book-cloth', 'text-kraft', 'text-book-cloth', 'text-kraft']
return colors[index % colors.length]
const getStatIconColor = (_index: number): string => {
return 'text-muted-foreground'
}
// 统计数据

View File

@@ -226,8 +226,8 @@
<div
v-for="announcement in announcements"
:key="announcement.id"
class="p-4 space-y-2 cursor-pointer transition-colors"
:class="[
'p-4 space-y-2 cursor-pointer transition-colors',
announcement.is_read ? 'hover:bg-muted/30' : 'bg-primary/5 hover:bg-primary/10'
]"
@click="viewAnnouncementDetail(announcement)"

View File

@@ -165,17 +165,17 @@
<TableCell class="py-4">
<div class="flex gap-1.5">
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
title="Vision"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
title="Tool Use"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
title="Extended Thinking"
/>
@@ -253,15 +253,15 @@
<!-- 第二行能力图标 -->
<div class="flex gap-1.5">
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
/>
</div>
@@ -474,24 +474,24 @@ async function toggleCapability(modelName: string, capName: string) {
const filteredModels = computed(() => {
let result = models.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选
if (capabilityFilters.value.vision) {
result = result.filter(m => m.default_supports_vision)
result = result.filter(m => m.config?.vision === true)
}
if (capabilityFilters.value.toolUse) {
result = result.filter(m => m.default_supports_function_calling)
result = result.filter(m => m.config?.function_calling === true)
}
if (capabilityFilters.value.extendedThinking) {
result = result.filter(m => m.default_supports_extended_thinking)
result = result.filter(m => m.config?.extended_thinking === true)
}
return result

View File

@@ -62,6 +62,7 @@
<Button
type="submit"
:disabled="savingProfile"
class="shadow-none hover:shadow-none"
>
{{ savingProfile ? '保存中...' : '保存修改' }}
</Button>
@@ -107,6 +108,7 @@
<Button
type="submit"
:disabled="changingPassword"
class="shadow-none hover:shadow-none"
>
{{ changingPassword ? '修改中...' : '修改密码' }}
</Button>
@@ -320,6 +322,7 @@
import { ref, onMounted } from 'vue'
import { useAuthStore } from '@/stores/auth'
import { meApi, type Profile } from '@/api/me'
import { useDarkMode, type ThemeMode } from '@/composables/useDarkMode'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue'
@@ -338,6 +341,7 @@ import { log } from '@/utils/logger'
const authStore = useAuthStore()
const { success, error: showError } = useToast()
const { setThemeMode } = useDarkMode()
const profile = ref<Profile | null>(null)
@@ -375,20 +379,8 @@ function handleThemeChange(value: string) {
themeSelectOpen.value = false
updatePreferences()
// 应用主题
if (value === 'dark') {
document.documentElement.classList.add('dark')
} else if (value === 'light') {
document.documentElement.classList.remove('dark')
} else {
// system: 跟随系统
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches
if (prefersDark) {
document.documentElement.classList.add('dark')
} else {
document.documentElement.classList.remove('dark')
}
}
// 使用 useDarkMode 统一切换主题
setThemeMode(value as ThemeMode)
}
function handleLanguageChange(value: string) {
@@ -418,10 +410,16 @@ async function loadProfile() {
async function loadPreferences() {
try {
const prefs = await meApi.getPreferences()
// 主题以本地 localStorage 为准useDarkMode 在应用启动时已初始化)
// 这样可以避免刷新页面时主题被服务端旧值覆盖
const { themeMode: currentThemeMode } = useDarkMode()
const localTheme = currentThemeMode.value
preferencesForm.value = {
avatar_url: prefs.avatar_url || '',
bio: prefs.bio || '',
theme: prefs.theme || 'light',
theme: localTheme, // 使用本地主题,而非服务端返回值
language: prefs.language || 'zh-CN',
timezone: prefs.timezone || 'Asia/Shanghai',
notifications: {
@@ -431,11 +429,12 @@ async function loadPreferences() {
}
}
// 应用主题
if (preferencesForm.value.theme === 'dark') {
document.documentElement.classList.add('dark')
} else if (preferencesForm.value.theme === 'light') {
document.documentElement.classList.remove('dark')
// 如果本地主题和服务端不一致,同步到服务端(静默更新,不提示用户)
const serverTheme = prefs.theme || 'light'
if (localTheme !== serverTheme) {
meApi.updatePreferences({ theme: localTheme }).catch(() => {
// 静默失败,不影响用户体验
})
}
} catch (error) {
log.error('加载偏好设置失败:', error)

View File

@@ -38,10 +38,10 @@
</button>
</div>
<p
v-if="model.description"
v-if="model.config?.description"
class="text-xs text-muted-foreground"
>
{{ model.description }}
{{ model.config?.description }}
</p>
</div>
<Button
@@ -73,10 +73,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -90,10 +90,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -107,10 +107,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
:variant="model.config?.vision === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
{{ model.config?.vision === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -124,10 +124,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -141,10 +141,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
</Badge>
</div>
</div>

12
migrate.sh Executable file
View File

@@ -0,0 +1,12 @@
#!/bin/bash
# 数据库迁移脚本 - 在 Docker 容器内执行 Alembic 迁移
set -e
CONTAINER_NAME="aether-app"
echo "Running database migrations in container: $CONTAINER_NAME"
docker exec $CONTAINER_NAME alembic upgrade head
echo "Database migration completed successfully"

View File

@@ -3,10 +3,8 @@
A proxy server that enables AI models to work with multiple API providers.
"""
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# 注意: dotenv 加载已统一移至 src/config/settings.py
# 不要在此处重复加载
try:
from src._version import __version__

View File

@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)
__all__ = ["router"]

View File

@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
allowed_providers=self.key_data.allowed_providers,
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit or 100,
rate_limit=self.key_data.rate_limit, # None 表示不限制
expire_days=self.key_data.expire_days,
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key

View File

@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
pipeline = ApiRequestPipeline()
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
"""对代理配置中的密码进行脱敏处理"""
if not proxy_config:
return None
masked = dict(proxy_config)
if masked.get("password"):
masked["password"] = "***"
return masked
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
provider_id: str,
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
"proxy": mask_proxy_password(endpoint.proxy),
}
endpoint_dict.pop("_sa_instance_state", None)
result.append(ProviderEndpointResponse(**endpoint_dict))
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
created_at=now,
updated_at=now,
)
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in new_endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=new_endpoint.api_format,
proxy=mask_proxy_password(new_endpoint.proxy),
total_keys=0,
active_keys=0,
)
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint_obj.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=endpoint_obj.api_format,
proxy=mask_proxy_password(endpoint_obj.proxy),
total_keys=total_keys,
active_keys=active_keys,
)
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
update_data = self.endpoint_data.model_dump(exclude_unset=True)
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
if "proxy" in update_data:
if update_data["proxy"] is not None:
new_proxy = dict(update_data["proxy"])
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
if "password" not in new_proxy and endpoint.proxy:
old_password = endpoint.proxy.get("password")
if old_password:
new_proxy["password"] = old_password
update_data["proxy"] = new_proxy
# proxy 为 None 时保留,用于清除代理配置
for field, value in update_data.items():
setattr(endpoint, field, value)
endpoint.updated_at = datetime.now(timezone.utc)
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name if provider else "Unknown",
api_format=endpoint.api_format,
proxy=mask_proxy_password(endpoint.proxy),
total_keys=total_keys,
active_keys=active_keys,
)

View File

@@ -5,6 +5,7 @@
from fastapi import APIRouter
from .catalog import router as catalog_router
from .external import router as external_router
from .global_models import router as global_models_router
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
@@ -12,3 +13,4 @@ router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"]
# 挂载子路由
router.include_router(catalog_router)
router.include_router(global_models_router)
router.include_router(external_router)

View File

@@ -72,10 +72,12 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
for gm in global_models:
gm_id = gm.id
provider_entries: List[ModelCatalogProviderDetail] = []
# 从 config JSON 读取能力标志
gm_config = gm.config or {}
capability_flags = {
"supports_vision": gm.default_supports_vision or False,
"supports_function_calling": gm.default_supports_function_calling or False,
"supports_streaming": gm.default_supports_streaming or False,
"supports_vision": gm_config.get("vision", False),
"supports_function_calling": gm_config.get("function_calling", False),
"supports_streaming": gm_config.get("streaming", True),
}
# 遍历该 GlobalModel 的所有关联提供商
@@ -140,7 +142,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
ModelCatalogItem(
global_model_name=gm.name,
display_name=gm.display_name,
description=gm.description,
description=gm_config.get("description"),
providers=provider_entries,
price_range=price_range,
total_providers=len(provider_entries),

View File

@@ -0,0 +1,141 @@
"""
models.dev 外部模型数据代理
"""
import json
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from src.clients import get_redis_client
from src.core.logger import logger
from src.models.database import User
from src.utils.auth_utils import require_admin
router = APIRouter()
CACHE_KEY = "aether:external:models_dev"
CACHE_TTL = 15 * 60 # 15 分钟
# 标记官方/一手提供商,前端可据此过滤第三方转售商
OFFICIAL_PROVIDERS = {
"anthropic", # Claude 官方
"openai", # OpenAI 官方
"google", # Gemini 官方
"google-vertex", # Google Vertex AI
"azure", # Azure OpenAI
"amazon-bedrock", # AWS Bedrock
"xai", # Grok 官方
"meta", # Llama 官方
"deepseek", # DeepSeek 官方
"mistral", # Mistral 官方
"cohere", # Cohere 官方
"zhipuai", # 智谱 AI 官方
"alibaba", # 阿里云(通义千问)
"minimax", # MiniMax 官方
"moonshot", # 月之暗面Kimi
"baichuan", # 百川智能
"ai21", # AI21 Labs
}
async def _get_cached_data() -> Optional[dict[str, Any]]:
"""从 Redis 获取缓存数据"""
redis = await get_redis_client()
if redis is None:
return None
try:
cached = await redis.get(CACHE_KEY)
if cached:
result: dict[str, Any] = json.loads(cached)
return result
except Exception as e:
logger.warning(f"读取 models.dev 缓存失败: {e}")
return None
async def _set_cached_data(data: dict) -> None:
"""将数据写入 Redis 缓存"""
redis = await get_redis_client()
if redis is None:
return
try:
await redis.setex(CACHE_KEY, CACHE_TTL, json.dumps(data, ensure_ascii=False))
except Exception as e:
logger.warning(f"写入 models.dev 缓存失败: {e}")
def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
"""为每个提供商标记是否为官方"""
result = {}
for provider_id, provider_data in data.items():
result[provider_id] = {
**provider_data,
"official": provider_id in OFFICIAL_PROVIDERS,
}
return result
@router.get("/external")
async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
"""
获取 models.dev 的模型数据(代理请求,解决跨域问题)
数据缓存 15 分钟(使用 Redis多 worker 共享)
每个提供商会标记 official 字段,前端可据此过滤
"""
# 检查缓存
cached = await _get_cached_data()
if cached is not None:
# 兼容旧缓存:如果没有 official 字段则补全并回写
try:
needs_mark = False
for provider_data in cached.values():
if not isinstance(provider_data, dict) or "official" not in provider_data:
needs_mark = True
break
if needs_mark:
marked_cached = _mark_official_providers(cached)
await _set_cached_data(marked_cached)
return JSONResponse(content=marked_cached)
except Exception as e:
logger.warning(f"处理 models.dev 缓存数据失败,将直接返回原缓存: {e}")
return JSONResponse(content=cached)
# 从 models.dev 获取数据
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get("https://models.dev/api.json")
response.raise_for_status()
data = response.json()
# 标记官方提供商
marked_data = _mark_official_providers(data)
# 写入缓存
await _set_cached_data(marked_data)
return JSONResponse(content=marked_data)
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="请求 models.dev 超时")
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=502, detail=f"models.dev 返回错误: {e.response.status_code}"
)
except Exception as e:
raise HTTPException(status_code=502, detail=f"获取外部模型数据失败: {str(e)}")
@router.delete("/external/cache")
async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
"""清除 models.dev 缓存"""
redis = await get_redis_client()
if redis is None:
return {"cleared": False, "message": "Redis 未启用"}
try:
await redis.delete(CACHE_KEY)
return {"cleared": True}
except Exception as e:
raise HTTPException(status_code=500, detail=f"清除缓存失败: {str(e)}")

View File

@@ -187,21 +187,15 @@ class AdminCreateGlobalModelAdapter(AdminApiAdapter):
db=context.db,
name=self.payload.name,
display_name=self.payload.display_name,
description=self.payload.description,
official_url=self.payload.official_url,
icon_url=self.payload.icon_url,
is_active=self.payload.is_active,
# 按次计费配置
default_price_per_request=self.payload.default_price_per_request,
# 阶梯计费配置
default_tiered_pricing=tiered_pricing_dict,
# 默认能力配置
default_supports_vision=self.payload.default_supports_vision,
default_supports_function_calling=self.payload.default_supports_function_calling,
default_supports_streaming=self.payload.default_supports_streaming,
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
# Key 能力配置
supported_capabilities=self.payload.supported_capabilities,
# 模型配置JSON
config=self.payload.config,
)
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")

View File

@@ -21,7 +21,8 @@ from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
from src.services.cache.affinity_manager import get_affinity_manager
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
from src.services.cache.aware_scheduler import CacheAwareScheduler, get_cache_aware_scheduler
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
pipeline = ApiRequestPipeline()
@@ -185,6 +186,30 @@ async def clear_user_cache(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/affinity/{affinity_key}/{endpoint_id}/{model_id}/{api_format}")
async def clear_single_affinity(
affinity_key: str,
endpoint_id: str,
model_id: str,
api_format: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
Clear a single cache affinity entry
Parameters:
- affinity_key: API Key ID
- endpoint_id: Endpoint ID
- model_id: Model ID (GlobalModel ID)
- api_format: API format (claude/openai)
"""
adapter = AdminClearSingleAffinityAdapter(
affinity_key=affinity_key, endpoint_id=endpoint_id, model_id=model_id, api_format=api_format
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("")
async def clear_all_cache(
request: Request,
@@ -250,7 +275,22 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
# 读取系统配置,确保监控接口与编排器使用一致的模式
priority_mode = SystemConfigService.get_config(
context.db,
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
context.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
scheduler = await get_cache_aware_scheduler(
redis_client,
priority_mode=priority_mode,
scheduling_mode=scheduling_mode,
)
stats = await scheduler.get_stats()
logger.info("缓存统计信息查询成功")
context.add_audit_metadata(
@@ -270,7 +310,22 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
# 读取系统配置,确保监控接口与编排器使用一致的模式
priority_mode = SystemConfigService.get_config(
context.db,
"provider_priority_mode",
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
)
scheduling_mode = SystemConfigService.get_config(
context.db,
"scheduling_mode",
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
)
scheduler = await get_cache_aware_scheduler(
redis_client,
priority_mode=priority_mode,
scheduling_mode=scheduling_mode,
)
stats = await scheduler.get_stats()
payload = self._format_prometheus(stats)
context.add_audit_metadata(
@@ -624,6 +679,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
"key_name": key.name if key else None,
"key_prefix": provider_key_masked,
"rate_multiplier": key.rate_multiplier if key else 1.0,
"global_model_id": affinity.get("model_name"), # 原始的 global_model_id
"model_name": (
global_model_map.get(affinity.get("model_name")).name
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
@@ -786,6 +842,65 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearSingleAffinityAdapter(AdminApiAdapter):
affinity_key: str
endpoint_id: str
model_id: str
api_format: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
# 直接获取指定的亲和性记录(无需遍历全部)
existing_affinity = await affinity_mgr.get_affinity(
self.affinity_key, self.api_format, self.model_id
)
if not existing_affinity:
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
# 验证 endpoint_id 是否匹配
if existing_affinity.endpoint_id != self.endpoint_id:
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
# 失效单条记录
await affinity_mgr.invalidate_affinity(
self.affinity_key, self.api_format, self.model_id, endpoint_id=self.endpoint_id
)
# 获取用于日志的信息
api_key = db.query(ApiKey).filter(ApiKey.id == self.affinity_key).first()
api_key_name = api_key.name if api_key else None
logger.info(
f"已清除单条缓存亲和性: affinity_key={self.affinity_key[:8]}..., "
f"endpoint_id={self.endpoint_id[:8]}..., model_id={self.model_id[:8]}..."
)
context.add_audit_metadata(
action="cache_clear_single",
affinity_key=self.affinity_key,
endpoint_id=self.endpoint_id,
model_id=self.model_id,
)
return {
"status": "ok",
"message": f"已清除缓存亲和性: {api_key_name or self.affinity_key[:8]}",
"affinity_key": self.affinity_key,
"endpoint_id": self.endpoint_id,
"model_id": self.model_id,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除单条缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class RequestTraceResponse(BaseModel):

View File

@@ -1,46 +1,28 @@
"""
Provider Query API 端点
用于查询提供商的余额、使用记录等信息
用于查询提供商的模型列表等信息
"""
from datetime import datetime
import asyncio
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# 初始化适配器注册
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.models.database import Provider, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
"""余额查询请求"""
provider_id: str
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
class UsageSummaryQueryRequest(BaseModel):
"""使用汇总查询请求"""
provider_id: str
api_key_id: Optional[str] = None
period: str = "month" # day, week, month, year
class ModelsQueryRequest(BaseModel):
"""模型列表查询请求"""
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
current_user: User = Depends(get_current_user),
):
"""
获取所有可用的查询适配器
async def _fetch_openai_models(
client: httpx.AsyncClient,
base_url: str,
api_key: str,
api_format: str,
extra_headers: Optional[dict] = None,
) -> tuple[list, Optional[str]]:
"""获取 OpenAI 格式的模型列表
Returns:
适配器列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
registry = get_query_registry()
adapters = registry.list_adapters()
headers = {"Authorization": f"Bearer {api_key}"}
if extra_headers:
# 防止 extra_headers 覆盖 Authorization
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
headers.update(safe_headers)
return {"success": True, "data": adapters}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
# 记录详细的错误信息
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
provider_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取提供商支持的查询能力
Args:
provider_id: 提供商 ID
async def _fetch_claude_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Claude 格式的模型列表
Returns:
支持的查询能力列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
# 获取提供商
from sqlalchemy import select
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
capabilities = registry.get_capabilities_for_provider(provider.name)
if capabilities is None:
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [],
"has_adapter": False,
"message": "No query adapter available for this provider",
},
}
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [c.name for c in capabilities],
"has_adapter": True,
},
headers = {
"x-api-key": api_key,
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
@router.post("/balance")
async def query_balance(
request: BalanceQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商余额
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg
Args:
request: 查询请求
async def _fetch_gemini_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Gemini 格式的模型列表
Returns:
余额信息
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 兼容 base_url 已包含 /v1beta 的情况
base_url_clean = base_url.rstrip("/")
if base_url_clean.endswith("/v1beta"):
models_url = f"{base_url_clean}/models?key={api_key}"
else:
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
# 查找指定的 API Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
try:
response = await client.get(models_url)
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
if "models" in data:
# 转换为统一格式
return [
{
"id": m.get("name", "").replace("models/", ""),
"owned_by": "google",
"display_name": m.get("displayName", ""),
"api_format": api_format,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 使用第一个可用的 API Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询余额
registry = get_query_registry()
query_result = await registry.query_provider_balance(
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
)
if not query_result.success:
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/usage-summary")
async def query_usage_summary(
request: UsageSummaryQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商使用汇总
Args:
request: 查询请求
Returns:
使用汇总信息
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key逻辑同上
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询使用汇总
registry = get_query_registry()
query_result = await registry.query_provider_usage(
provider_type=provider.name,
api_key=api_key_value,
period=request.period,
endpoint_config=endpoint_config,
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
for m in data["models"]
], None
return [], None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg
@router.post("/models")
async def query_available_models(
request: ModelsQueryRequest,
db: AsyncSession = Depends(get_db),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商可用模型
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
Args:
request: 查询请求
Returns:
模型列表
所有端点的模型列表(合并)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
# 收集所有活跃端点的配置
endpoint_configs: list[dict] = []
if request.api_key_id:
# 指定了特定的 API Key只使用该 Key 对应的端点
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break
if api_key_value:
if endpoint_configs:
break
if not api_key_value:
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 遍历所有活跃端点,每个端点取第一个可用的 Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not endpoint.is_active or not endpoint.api_keys:
continue
if not api_key_value:
# 找第一个可用的 Key
for api_key in endpoint.api_keys:
if api_key.is_active:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # 尝试下一个 Key
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询模型
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
# 并发请求所有端点的模型列表
all_models: list = []
errors: list[str] = []
if not adapter:
raise HTTPException(
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
async def fetch_endpoint_models(
client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
base_url = config["base_url"]
if not base_url:
return [], None
base_url = base_url.rstrip("/")
api_format = config["api_format"]
api_key_value = config["api_key"]
extra_headers = config["extra_headers"]
try:
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
elif api_format in ["GEMINI", "GEMINI_CLI"]:
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
else:
return await _fetch_openai_models(
client, base_url, api_key_value, api_format, extra_headers
)
except Exception as e:
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
return [], f"{api_format}: {str(e)}"
async with httpx.AsyncClient(timeout=30.0) as client:
results = await asyncio.gather(
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
)
for models, error in results:
all_models.extend(models)
if error:
errors.append(error)
query_result = await adapter.query_available_models(
api_key=api_key_value, endpoint_config=endpoint_config
)
# 按 model id 去重(保留第一个)
seen_ids: set[str] = set()
unique_models: list = []
for model in all_models:
model_id = model.get("id")
if model_id and model_id not in seen_ids:
seen_ids.add(model_id)
unique_models.append(model)
error = "; ".join(errors) if errors else None
if not unique_models and not error:
error = "No models returned from any endpoint"
return {
"success": query_result.success,
"data": query_result.to_dict(),
"success": len(unique_models) > 0,
"data": {"models": unique_models, "error": error},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
provider_id: str,
api_key_id: Optional[str] = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
清除查询缓存
Args:
provider_id: 提供商 ID
api_key_id: 可选,指定清除某个 API Key 的缓存
Returns:
清除结果
"""
from sqlalchemy import select
# 获取提供商
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if adapter:
if api_key_id:
# 获取 API Key 值来清除缓存
from sqlalchemy.orm import selectinload
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
api_key = result.scalar_one_or_none()
if api_key:
adapter.clear_cache(api_key.api_key)
else:
adapter.clear_cache()
return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.models_service import invalidate_models_list_cache
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
@@ -419,4 +420,8 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
)
# 清除 /v1/models 列表缓存
if success:
await invalidate_models_list_cache()
return BatchAssignModelsToProviderResponse(success=success, errors=errors)

View File

@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
"""导出提供商和模型配置(管理员)"""
adapter = AdminExportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
"""导入提供商和模型配置(管理员)"""
adapter = AdminImportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
"""导出用户数据(管理员)"""
adapter = AdminExportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
"""导入用户数据(管理员)"""
adapter = AdminImportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 系统设置适配器 --------
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
)
return {"formats": formats}
class AdminExportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出提供商和模型配置(解密数据)"""
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
db = context.db
# 导出 GlobalModels
global_models = db.query(GlobalModel).all()
global_models_data = []
for gm in global_models:
global_models_data.append(
{
"name": gm.name,
"display_name": gm.display_name,
"default_price_per_request": gm.default_price_per_request,
"default_tiered_pricing": gm.default_tiered_pricing,
"supported_capabilities": gm.supported_capabilities,
"config": gm.config,
"is_active": gm.is_active,
}
)
# 导出 Providers 及其关联数据
providers = db.query(Provider).all()
providers_data = []
for provider in providers:
# 导出 Endpoints
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == provider.id)
.all()
)
endpoints_data = []
for ep in endpoints:
# 导出 Endpoint Keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# 解密 API Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
}
)
# 导出 Provider Models
models = db.query(Model).filter(Model.provider_id == provider.id).all()
models_data = []
for model in models:
# 获取关联的 GlobalModel 名称
global_model = (
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
)
models_data.append(
{
"global_model_name": global_model.name if global_model else None,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": model.provider_model_aliases,
"price_per_request": model.price_per_request,
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": model.supports_image_generation,
"is_active": model.is_active,
"config": model.config,
}
)
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"rate_limit": provider.rate_limit,
"concurrent_limit": provider.concurrent_limit,
"config": provider.config,
"endpoints": endpoints_data,
"models": models_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
}
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
class AdminImportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入提供商和模型配置"""
import uuid
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.core.enums import ProviderBillingType
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
global_models_data = payload.get("global_models", [])
providers_data = payload.get("providers", [])
stats = {
"global_models": {"created": 0, "updated": 0, "skipped": 0},
"providers": {"created": 0, "updated": 0, "skipped": 0},
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
"keys": {"created": 0, "updated": 0, "skipped": 0},
"models": {"created": 0, "updated": 0, "skipped": 0},
"errors": [],
}
try:
# 导入 GlobalModels
global_model_map = {} # name -> id 映射
for gm_data in global_models_data:
existing = (
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
)
if existing:
global_model_map[gm_data["name"]] = existing.id
if merge_mode == "skip":
stats["global_models"]["skipped"] += 1
continue
elif merge_mode == "error":
raise InvalidRequestException(
f"GlobalModel '{gm_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing.display_name = gm_data.get(
"display_name", existing.display_name
)
existing.default_price_per_request = gm_data.get(
"default_price_per_request"
)
existing.default_tiered_pricing = gm_data.get(
"default_tiered_pricing", existing.default_tiered_pricing
)
existing.supported_capabilities = gm_data.get(
"supported_capabilities"
)
existing.config = gm_data.get("config")
existing.is_active = gm_data.get("is_active", True)
existing.updated_at = datetime.now(timezone.utc)
stats["global_models"]["updated"] += 1
else:
# 创建新记录
new_gm = GlobalModel(
id=str(uuid.uuid4()),
name=gm_data["name"],
display_name=gm_data.get("display_name", gm_data["name"]),
default_price_per_request=gm_data.get("default_price_per_request"),
default_tiered_pricing=gm_data.get(
"default_tiered_pricing",
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
),
supported_capabilities=gm_data.get("supported_capabilities"),
config=gm_data.get("config"),
is_active=gm_data.get("is_active", True),
)
db.add(new_gm)
db.flush()
global_model_map[gm_data["name"]] = new_gm.id
stats["global_models"]["created"] += 1
# 导入 Providers
for prov_data in providers_data:
existing_provider = (
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
)
if existing_provider:
provider_id = existing_provider.id
if merge_mode == "skip":
stats["providers"]["skipped"] += 1
# 仍然需要处理 endpoints 和 models如果存在
elif merge_mode == "error":
raise InvalidRequestException(
f"Provider '{prov_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
if prov_data.get("billing_type"):
existing_provider.billing_type = ProviderBillingType(
prov_data["billing_type"]
)
existing_provider.monthly_quota_usd = prov_data.get(
"monthly_quota_usd"
)
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
existing_provider.is_active = prov_data.get("is_active", True)
existing_provider.rate_limit = prov_data.get("rate_limit")
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
else:
# 创建新 Provider
billing_type = ProviderBillingType.PAY_AS_YOU_GO
if prov_data.get("billing_type"):
billing_type = ProviderBillingType(prov_data["billing_type"])
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
rate_limit=prov_data.get("rate_limit"),
concurrent_limit=prov_data.get("concurrent_limit"),
config=prov_data.get("config"),
)
db.add(new_provider)
db.flush()
provider_id = new_provider.id
stats["providers"]["created"] += 1
# 导入 Endpoints
for ep_data in prov_data.get("endpoints", []):
existing_ep = (
db.query(ProviderEndpoint)
.filter(
ProviderEndpoint.provider_id == provider_id,
ProviderEndpoint.api_format == ep_data["api_format"],
)
.first()
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_ep.base_url = ep_data.get(
"base_url", existing_ep.base_url
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
new_ep = ProviderEndpoint(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_format=ep_data["api_format"],
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# 导入 Keys
# 获取当前 endpoint 下所有已有的 keys用于去重
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
.all()
)
# 解密已有 keys 用于比对
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
)
continue
# 检查是否已存在相同的 Key通过明文比对
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
)
db.add(new_key)
# 添加到已有集合,防止同一批导入中重复
existing_key_values.add(key_data["api_key"])
stats["keys"]["created"] += 1
# 导入 Models
for model_data in prov_data.get("models", []):
global_model_name = model_data.get("global_model_name")
if not global_model_name:
stats["errors"].append(
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
)
continue
global_model_id = global_model_map.get(global_model_name)
if not global_model_id:
# 尝试从数据库查找
existing_gm = (
db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name)
.first()
)
if existing_gm:
global_model_id = existing_gm.id
else:
stats["errors"].append(
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
)
continue
existing_model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.provider_model_name == model_data["provider_model_name"],
)
.first()
)
if existing_model:
if merge_mode == "skip":
stats["models"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_model.global_model_id = global_model_id
existing_model.provider_model_aliases = model_data.get(
"provider_model_aliases"
)
existing_model.price_per_request = model_data.get(
"price_per_request"
)
existing_model.tiered_pricing = model_data.get(
"tiered_pricing"
)
existing_model.supports_vision = model_data.get(
"supports_vision"
)
existing_model.supports_function_calling = model_data.get(
"supports_function_calling"
)
existing_model.supports_streaming = model_data.get(
"supports_streaming"
)
existing_model.supports_extended_thinking = model_data.get(
"supports_extended_thinking"
)
existing_model.supports_image_generation = model_data.get(
"supports_image_generation"
)
existing_model.is_active = model_data.get("is_active", True)
existing_model.config = model_data.get("config")
existing_model.updated_at = datetime.now(timezone.utc)
stats["models"]["updated"] += 1
else:
new_model = Model(
id=str(uuid.uuid4()),
provider_id=provider_id,
global_model_id=global_model_id,
provider_model_name=model_data["provider_model_name"],
provider_model_aliases=model_data.get(
"provider_model_aliases"
),
price_per_request=model_data.get("price_per_request"),
tiered_pricing=model_data.get("tiered_pricing"),
supports_vision=model_data.get("supports_vision"),
supports_function_calling=model_data.get(
"supports_function_calling"
),
supports_streaming=model_data.get("supports_streaming"),
supports_extended_thinking=model_data.get(
"supports_extended_thinking"
),
supports_image_generation=model_data.get(
"supports_image_generation"
),
is_active=model_data.get("is_active", True),
config=model_data.get("config"),
)
db.add(new_model)
stats["models"]["created"] += 1
db.commit()
# 失效缓存
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.clear_all_caches()
return {
"message": "配置导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")
class AdminExportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出用户数据(保留加密数据,排除管理员)"""
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
db = context.db
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# 导出用户的 API Keys保留加密数据
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
users_data.append(
{
"email": user.email,
"username": user.username,
"password_hash": user.password_hash,
"role": user.role.value if user.role else "user",
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"model_capability_settings": user.model_capability_settings,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
"is_active": user.is_active,
"api_keys": api_keys_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
}
class AdminImportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入用户数据"""
import uuid
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
users_data = payload.get("users", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"errors": [],
}
try:
for user_data in users_data:
# 跳过管理员角色的导入(不区分大小写)
role_str = str(user_data.get("role", "")).lower()
if role_str == "admin":
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
stats["users"]["skipped"] += 1
continue
existing_user = (
db.query(User).filter(User.email == user_data["email"]).first()
)
if existing_user:
user_id = existing_user.id
if merge_mode == "skip":
stats["users"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"用户 '{user_data['email']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有用户
existing_user.username = user_data.get(
"username", existing_user.username
)
if user_data.get("password_hash"):
existing_user.password_hash = user_data["password_hash"]
if user_data.get("role"):
existing_user.role = UserRole(user_data["role"])
existing_user.allowed_providers = user_data.get("allowed_providers")
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
existing_user.allowed_models = user_data.get("allowed_models")
existing_user.model_capability_settings = user_data.get(
"model_capability_settings"
)
existing_user.quota_usd = user_data.get("quota_usd")
existing_user.used_usd = user_data.get("used_usd", 0.0)
existing_user.total_usd = user_data.get("total_usd", 0.0)
existing_user.is_active = user_data.get("is_active", True)
existing_user.updated_at = datetime.now(timezone.utc)
stats["users"]["updated"] += 1
else:
# 创建新用户
role = UserRole.USER
if user_data.get("role"):
role = UserRole(user_data["role"])
new_user = User(
id=str(uuid.uuid4()),
email=user_data["email"],
username=user_data.get("username", user_data["email"].split("@")[0]),
password_hash=user_data.get("password_hash", ""),
role=role,
allowed_providers=user_data.get("allowed_providers"),
allowed_endpoints=user_data.get("allowed_endpoints"),
allowed_models=user_data.get("allowed_models"),
model_capability_settings=user_data.get("model_capability_settings"),
quota_usd=user_data.get("quota_usd"),
used_usd=user_data.get("used_usd", 0.0),
total_usd=user_data.get("total_usd", 0.0),
is_active=user_data.get("is_active", True),
)
db.add(new_user)
db.flush()
user_id = new_user.id
stats["users"]["created"] += 1
# 导入 API Keys
for key_data in user_data.get("api_keys", []):
# 检查是否已存在相同的 key_hash
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
db.add(new_key)
stats["api_keys"]["created"] += 1
db.commit()
return {
"message": "用户数据导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")

View File

@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"actual_cost": actual_cost,
"rate_multiplier": rate_multiplier,
"response_time_ms": usage.response_time_ms,
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage.created_at.isoformat(),
"is_stream": usage.is_stream,
"input_price_per_1m": usage.input_price_per_1m,
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"status_code": usage_record.status_code,
"error_message": usage_record.error_message,
"response_time_ms": usage_record.response_time_ms,
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
"request_headers": usage_record.request_headers,
"request_body": usage_record.get_request_body(),

View File

@@ -140,9 +140,9 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
if not authorization or not authorization.lower().startswith("bearer "):
return None
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:
return None

View File

@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
from ..models.database import SystemConfig
from src.models.database import SystemConfig
db = context.db
payload = context.ensure_json_body()

View File

@@ -55,6 +55,23 @@ async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"])
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
async def invalidate_models_list_cache() -> None:
"""
清除所有 /v1/models 列表缓存
在模型创建、更新、删除时调用,确保模型列表实时更新
"""
# 清除所有格式的缓存
all_formats = ["CLAUDE", "OPENAI", "GEMINI"]
for fmt in all_formats:
cache_key = f"{_CACHE_KEY_PREFIX}:{fmt}"
try:
await CacheService.delete(cache_key)
logger.debug(f"[ModelsService] 已清除缓存: {cache_key}")
except Exception as e:
logger.warning(f"[ModelsService] 清除缓存失败 {cache_key}: {e}")
@dataclass
class ModelInfo:
"""统一的模型信息结构"""
@@ -65,6 +82,21 @@ class ModelInfo:
created_at: Optional[str] # ISO 格式
created_timestamp: int # Unix 时间戳
provider_name: str
# 能力配置
streaming: bool = True
vision: bool = False
function_calling: bool = False
extended_thinking: bool = False
image_generation: bool = False
structured_output: bool = False
# 规格参数
context_limit: Optional[int] = None
output_limit: Optional[int] = None
# 元信息
family: Optional[str] = None
knowledge_cutoff: Optional[str] = None
input_modalities: Optional[list[str]] = None
output_modalities: Optional[list[str]] = None
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
@@ -181,13 +213,19 @@ def _extract_model_info(model: Any) -> ModelInfo:
global_model = model.global_model
model_id: str = global_model.name if global_model else model.provider_model_name
display_name: str = global_model.display_name if global_model else model.provider_model_name
description: Optional[str] = global_model.description if global_model else None
created_at: Optional[str] = (
model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
)
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
provider_name: str = model.provider.name if model.provider else "unknown"
# 从 GlobalModel.config 提取配置信息
config: dict = {}
description: Optional[str] = None
if global_model:
config = global_model.config or {}
description = config.get("description")
return ModelInfo(
id=model_id,
display_name=display_name,
@@ -195,6 +233,21 @@ def _extract_model_info(model: Any) -> ModelInfo:
created_at=created_at,
created_timestamp=created_timestamp,
provider_name=provider_name,
# 能力配置
streaming=config.get("streaming", True),
vision=config.get("vision", False),
function_calling=config.get("function_calling", False),
extended_thinking=config.get("extended_thinking", False),
image_generation=config.get("image_generation", False),
structured_output=config.get("structured_output", False),
# 规格参数
context_limit=config.get("context_limit"),
output_limit=config.get("output_limit"),
# 元信息
family=config.get("family"),
knowledge_cutoff=config.get("knowledge_cutoff"),
input_modalities=config.get("input_modalities"),
output_modalities=config.get("output_modalities"),
)

View File

@@ -5,13 +5,12 @@ from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
@@ -178,9 +177,9 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少管理员凭证")
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -205,9 +204,9 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少用户凭证")
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
"""记录审计事件
事务策略:复用请求级 Session不单独提交。
审计记录随主事务一起提交,由中间件统一管理。
"""
if not getattr(adapter, "audit_log_enabled", True):
return
bind = context.db.get_bind()
if bind is None:
if context.db is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
error=error,
)
SessionMaker = sessionmaker(bind=bind)
audit_session = SessionMaker()
try:
# 复用请求级 Session不创建新的连接
# 审计记录随主事务一起提交,由中间件统一管理
self.audit_service.log_event(
db=audit_session,
db=context.db,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
error_message=error,
metadata=metadata,
)
audit_session.commit()
except Exception as exc:
audit_session.rollback()
# 审计失败不应影响主请求,仅记录警告
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
finally:
audit_session.close()
def _build_audit_metadata(
self,

View File

@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
)
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 需要转回业务时区再取日期,才能与日期序列匹配
def _to_business_date_str(value: datetime) -> str:
if value.tzinfo is None:
value_utc = value.replace(tzinfo=timezone.utc)
else:
value_utc = value.astimezone(timezone.utc)
return value_utc.astimezone(app_tz).date().isoformat()
stats_map = {
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
_to_business_date_str(stat.date): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
"unique_providers": today_unique_providers,
"fallback_count": today_fallback_count,
}
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
yesterday_date = today_local.date() - timedelta(days=1)
historical_end = min(end_date_local.date(), yesterday_date)
missing_dates: list[str] = []
cursor = start_date_local.date()
while cursor <= historical_end:
date_str = cursor.isoformat()
if date_str not in stats_map:
missing_dates.append(date_str)
cursor += timedelta(days=1)
if missing_dates:
for date_str in missing_dates[-7:]:
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
stats_map[date_str] = {
"requests": computed["total_requests"],
"tokens": (
computed["input_tokens"]
+ computed["output_tokens"]
+ computed["cache_creation_tokens"]
+ computed["cache_read_tokens"]
),
"cost": computed["total_cost"],
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
if computed["avg_response_time_ms"]
else 0,
"unique_models": computed["unique_models"],
"unique_providers": computed["unique_providers"],
"fallback_count": computed["fallback_count"],
}
else:
# 普通用户:仍需实时查询(用户级预聚合可选)
query = db.query(Usage).filter(

View File

@@ -28,7 +28,7 @@
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
if TYPE_CHECKING:
from src.api.handlers.base.stream_context import StreamContext
class MessageTelemetry:
@@ -100,6 +103,8 @@ class MessageTelemetry:
cache_read_tokens: int = 0,
is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None,
# 时间指标
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
# Provider 侧追踪信息(用于记录真实成本)
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
@@ -133,6 +138,7 @@ class MessageTelemetry:
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
status_code=status_code,
request_headers=request_headers,
request_body=request_body,
@@ -396,6 +402,41 @@ class BaseMessageHandler:
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
"""更新 Usage 状态为 streaming同时更新 provider 和 target_model
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
Args:
ctx: 流式上下文,包含 provider_name 和 mapped_model
"""
import asyncio
from src.database.database import get_db
target_request_id = self.request_id
provider = ctx.provider_name
target_model = ctx.mapped_model
async def _do_update() -> None:
try:
db_gen = get_db()
db = next(db_gen)
try:
UsageService.update_usage_status(
db=db,
request_id=target_request_id,
status="streaming",
provider=provider,
target_model=target_model,
)
finally:
db.close()
except Exception as e:
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
@@ -408,9 +449,10 @@ class BaseMessageHandler:
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
UpstreamClientException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
# 业务异常:简洁日志,不打印堆栈
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:

View File

@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
self.response_normalizer = None
# 可选启用响应规范化
self._init_response_normalizer()
def _init_response_normalizer(self):
"""初始化响应规范化器 - 子类可覆盖"""
try:
from src.services.provider.response_normalizer import ResponseNormalizer
self.response_normalizer = ResponseNormalizer()
except ImportError:
pass
async def handle(self, context: ApiRequestContext):
"""处理 Chat API 请求"""
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
response_normalizer=self.response_normalizer,
enable_response_normalization=self.response_normalizer is not None,
adapter_detector=self.detect_capability_requirements,
)

View File

@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.exceptions import (
EmbeddedErrorException,
@@ -87,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
response_normalizer: Optional[Any] = None,
enable_response_normalization: bool = False,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
@@ -105,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
self._parser: Optional[ResponseParser] = None
self._request_builder = PassthroughRequestBuilder()
self.response_normalizer = response_normalizer
self.enable_response_normalization = enable_response_normalization
@property
def parser(self) -> ResponseParser:
@@ -265,8 +262,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if mapping and mapping.model:
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key实现相同用户稳定选择同一别名
# 传入 api_format 用于过滤适用的别名作用域
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name
@@ -293,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# 创建类型安全的流式上下文
ctx = StreamContext(model=model, api_format=api_format)
# 创建更新状态的回调闭包(可以访问 ctx
def update_streaming_status() -> None:
self._update_usage_to_streaming_with_ctx(ctx)
# 创建流处理器
stream_processor = StreamProcessor(
request_id=self.request_id,
default_parser=self.parser,
on_streaming_start=self._update_usage_to_streaming,
on_streaming_start=update_streaming_status,
)
# 定义请求函数
@@ -365,7 +369,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
ctx,
original_headers,
original_request_body,
self.elapsed_ms(),
self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
)
# 创建监控流
@@ -378,6 +382,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
@@ -461,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
pool=config.http_pool_timeout,
)
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -473,12 +484,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
stream_response.raise_for_status()
# 创建行迭代器
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
byte_iterator = stream_response.aiter_raw()
# 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator,
prefetched_chunks = await stream_processor.prefetch_and_check_error(
byte_iterator,
provider,
endpoint,
ctx,
@@ -503,13 +515,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 创建流生成器
# 创建流生成器(传入字节流迭代器)
return stream_processor.create_response_stream(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
start_time=self.start_time,
)
async def _record_stream_failure(
@@ -626,11 +639,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
f"模型={model} -> {mapped_model or '无映射'}")
logger.debug(f" [{self.request_id}] 请求URL: {url}")
logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}")
async with httpx.AsyncClient(
timeout=float(endpoint.timeout),
follow_redirects=True,
) as http_client:
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code
@@ -645,10 +664,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
response_headers=response_headers,
)
elif resp.status_code >= 500:
raise ProviderNotAvailableException(f"提供商服务不可用: {provider.name}")
elif resp.status_code != 200:
# 记录响应体以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}")
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}"
f"提供商服务不可用: {provider.name}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
elif resp.status_code != 200:
# 记录非200响应以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}")
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
response_json = resp.json()

View File

@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
"""
import asyncio
import codecs
import json
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
)
import httpx
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamContext:
"""流式请求的上下文信息"""
# 请求信息
model: str = "unknown" # 用户请求的原始模型名
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
api_format: str = ""
request_id: str = ""
# 用户信息(提前提取避免 Session detached
user_id: int = 0
api_key_id: int = 0
# 统计信息
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0 # cache_read_input_tokens
cache_creation_tokens: int = 0 # cache_creation_input_tokens
collected_text: str = ""
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
parsed_chunks: list = field(default_factory=list)
# 流状态
start_time: float = field(default_factory=time.time)
chunk_count: int = 0
data_count: int = 0
has_completion: bool = False
# 响应信息
status_code: int = 200
response_headers: Dict[str, str] = field(default_factory=dict)
# 请求信息(发送给 Provider 的)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
# Provider 信息
provider_name: Optional[str] = None
provider_id: Optional[str] = None # Provider ID用于记录真实成本
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
error_message: Optional[str] = None
# 格式转换信息
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
client_api_format: str = "" # 客户端请求的 API 格式
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion
response_metadata: Dict[str, Any] = field(default_factory=dict)
class CliMessageHandlerBase(BaseMessageHandler):
"""
CLI 格式消息处理器基类
@@ -212,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
if mapping and mapping.model:
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key实现相同用户稳定选择同一别名
# 传入 api_format 用于过滤适用的别名作用域
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
@@ -409,6 +355,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
@@ -433,7 +380,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.chunk_count = 0
ctx.data_count = 0
ctx.has_completion = False
ctx.collected_text = ""
ctx._collected_text_parts = [] # 重置文本收集
ctx.input_tokens = 0
ctx.output_tokens = 0
ctx.cached_tokens = 0
@@ -507,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -521,12 +474,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题
byte_iterator = stream_response.aiter_raw()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
prefetched_chunks = await self._prefetch_and_check_embedded_error(
byte_iterator, provider, endpoint, ctx
)
except httpx.HTTPStatusError as e:
@@ -551,10 +504,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
)
async def _create_response_stream(
@@ -564,58 +517,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
"""创建响应流生成器(使用字节流)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines():
async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
yield b"\n"
continue
continue
ctx.chunk_count += 1
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
@@ -689,7 +659,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -703,20 +673,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流上下文
Returns:
预读的列表(需要在后续流中先输出)
预读的字节块列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
"""
prefetched_lines: list = []
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
# 获取对应格式的解析器
@@ -729,69 +704,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
else:
provider_parser = self.parser
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
# 解析数据
normalized_line = line.rstrip("\r")
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
line_count += 1
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
@@ -800,112 +792,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def _create_response_stream_with_prefetch(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
prefetched_chunks: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
"""创建响应流生成器(带预读数据,使用字节流"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming(ctx.request_id)
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
# 先处理预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 继续处理剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
async for chunk in byte_iterator:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件
flushed_events = sse_parser.flush()
@@ -1034,7 +1082,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 提取文本内容
text = self.parser.extract_text_content(data)
if text:
ctx.collected_text += text
ctx.append_text(text)
# 检查完成事件
if event_type in ("response.completed", "message_stop"):
@@ -1066,8 +1114,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
async for chunk in stream_generator:
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
# 如果响应已完成,不标记为失败
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
raise
except httpx.TimeoutException as e:
ctx.status_code = 504
@@ -1086,9 +1136,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
) -> None:
"""在流完成后记录统计信息"""
try:
await asyncio.sleep(0.1)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
# 注意:不要把统计延迟算进响应时间里
response_time_ms = int((time.time() - self.start_time) * 1000)
response_time_ms = int((time.time() - ctx.start_time) * 1000)
await asyncio.sleep(0.1)
if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
@@ -1168,6 +1220,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
@@ -1188,9 +1241,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
# 简洁的请求完成摘要(两行格式)
line1 = (
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
)
logger.info(f"{line1}\n{line2}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
@@ -1242,7 +1304,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body: Dict[str, Any],
) -> None:
"""记录流式请求失败"""
response_time_ms = int((time.time() - ctx.start_time) * 1000)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
response_time_ms = int((time.time() - self.start_time) * 1000)
status_code = 503
if isinstance(error, ProviderAuthException):
@@ -1364,10 +1427,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
async with httpx.AsyncClient(
timeout=float(endpoint.timeout),
follow_redirects=True,
) as http_client:
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code

View File

@@ -0,0 +1,274 @@
"""
流式内容提取器 - 策略模式实现
为不同 API 格式OpenAI、Claude、Gemini提供内容提取和 chunk 构造的抽象。
StreamSmoother 使用这些提取器来处理不同格式的 SSE 事件。
"""
import copy
import json
from abc import ABC, abstractmethod
from typing import Optional
class ContentExtractor(ABC):
"""
流式内容提取器抽象基类
定义从 SSE 事件中提取文本内容和构造新 chunk 的接口。
每种 API 格式OpenAI、Claude、Gemini需要实现自己的提取器。
"""
@abstractmethod
def extract_content(self, data: dict) -> Optional[str]:
"""
从 SSE 数据中提取可拆分的文本内容
Args:
data: 解析后的 JSON 数据
Returns:
提取的文本内容,如果无法提取则返回 None
"""
pass
@abstractmethod
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
"""
使用新内容构造 SSE chunk
Args:
original_data: 原始 JSON 数据
new_content: 新的文本内容
event_type: SSE 事件类型(某些格式需要)
is_first: 是否是第一个 chunk用于保留 role 等字段)
Returns:
编码后的 SSE 字节数据
"""
pass
class OpenAIContentExtractor(ContentExtractor):
"""
OpenAI 格式内容提取器
处理 OpenAI Chat Completions API 的流式响应格式:
- 数据结构: choices[0].delta.content
- 只在 delta 仅包含 role/content 时允许拆分,避免破坏 tool_calls 等结构
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
choices = data.get("choices")
if not isinstance(choices, list) or len(choices) != 1:
return None
first_choice = choices[0]
if not isinstance(first_choice, dict):
return None
delta = first_choice.get("delta")
if not isinstance(delta, dict):
return None
content = delta.get("content")
if not isinstance(content, str):
return None
# 只有 delta 仅包含 role/content 时才允许拆分
# 避免破坏 tool_calls、function_call 等复杂结构
allowed_keys = {"role", "content"}
if not all(key in allowed_keys for key in delta.keys()):
return None
return content
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = original_data.copy()
if "choices" in new_data and new_data["choices"]:
new_choices = []
for choice in new_data["choices"]:
new_choice = choice.copy()
if "delta" in new_choice:
new_delta = {}
# 只有第一个 chunk 保留 role
if is_first and "role" in new_choice["delta"]:
new_delta["role"] = new_choice["delta"]["role"]
new_delta["content"] = new_content
new_choice["delta"] = new_delta
new_choices.append(new_choice)
new_data["choices"] = new_choices
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
class ClaudeContentExtractor(ContentExtractor):
"""
Claude 格式内容提取器
处理 Claude Messages API 的流式响应格式:
- 事件类型: content_block_delta
- 数据结构: delta.type=text_delta, delta.text
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
# 检查事件类型
if data.get("type") != "content_block_delta":
return None
delta = data.get("delta", {})
if not isinstance(delta, dict):
return None
# 检查 delta 类型
if delta.get("type") != "text_delta":
return None
text = delta.get("text")
if not isinstance(text, str):
return None
return text
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = original_data.copy()
if "delta" in new_data:
new_delta = new_data["delta"].copy()
new_delta["text"] = new_content
new_data["delta"] = new_delta
# Claude 格式需要 event: 前缀
event_name = event_type or "content_block_delta"
return f"event: {event_name}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
"utf-8"
)
class GeminiContentExtractor(ContentExtractor):
"""
Gemini 格式内容提取器
处理 Gemini API 的流式响应格式:
- 数据结构: candidates[0].content.parts[0].text
- 只有纯文本块才拆分
"""
def extract_content(self, data: dict) -> Optional[str]:
if not isinstance(data, dict):
return None
candidates = data.get("candidates")
if not isinstance(candidates, list) or len(candidates) != 1:
return None
first_candidate = candidates[0]
if not isinstance(first_candidate, dict):
return None
content = first_candidate.get("content", {})
if not isinstance(content, dict):
return None
parts = content.get("parts", [])
if not isinstance(parts, list) or len(parts) != 1:
return None
first_part = parts[0]
if not isinstance(first_part, dict):
return None
text = first_part.get("text")
# 只有纯文本块(只有 text 字段)才拆分
if not isinstance(text, str) or len(first_part) != 1:
return None
return text
def create_chunk(
self,
original_data: dict,
new_content: str,
event_type: str = "",
is_first: bool = False,
) -> bytes:
new_data = copy.deepcopy(original_data)
if "candidates" in new_data and new_data["candidates"]:
first_candidate = new_data["candidates"][0]
if "content" in first_candidate:
content = first_candidate["content"]
if "parts" in content and content["parts"]:
content["parts"][0]["text"] = new_content
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
# 提取器注册表
_EXTRACTORS: dict[str, type[ContentExtractor]] = {
"openai": OpenAIContentExtractor,
"claude": ClaudeContentExtractor,
"gemini": GeminiContentExtractor,
}
def get_extractor(format_name: str) -> Optional[ContentExtractor]:
"""
根据格式名获取对应的内容提取器实例
Args:
format_name: 格式名称openai, claude, gemini
Returns:
对应的提取器实例,如果格式不支持则返回 None
"""
extractor_class = _EXTRACTORS.get(format_name.lower())
if extractor_class:
return extractor_class()
return None
def register_extractor(format_name: str, extractor_class: type[ContentExtractor]) -> None:
"""
注册新的内容提取器
Args:
format_name: 格式名称
extractor_class: 提取器类
"""
_EXTRACTORS[format_name.lower()] = extractor_class
def get_extractor_formats() -> list[str]:
"""
获取所有已注册的格式名称列表
Returns:
格式名称列表
"""
return list(_EXTRACTORS.keys())

View File

@@ -8,6 +8,7 @@
- 请求/响应数据
"""
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@@ -25,12 +26,18 @@ class StreamContext:
model: str
api_format: str
# 请求标识信息CLI handler 需要)
request_id: str = ""
user_id: int = 0
api_key_id: int = 0
# Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射
@@ -43,7 +50,14 @@ class StreamContext:
cache_creation_tokens: int = 0
# 响应内容
collected_text: str = ""
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
# 时间指标
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
start_time: float = field(default_factory=time.time)
# 响应状态
status_code: int = 200
@@ -55,6 +69,12 @@ class StreamContext:
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
# 格式转换信息CLI handler 需要)
client_api_format: str = ""
# Provider 响应元数据CLI handler 需要)
response_metadata: Dict[str, Any] = field(default_factory=dict)
# 流式处理统计
data_count: int = 0
chunk_count: int = 0
@@ -71,16 +91,30 @@ class StreamContext:
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self.collected_text = ""
self._collected_text_parts = []
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
self.response_id = None
self.final_usage = None
self.final_response = None
@property
def collected_text(self) -> str:
"""已收集的文本内容(按需拼接,避免在流式过程中频繁做字符串拷贝)"""
return "".join(self._collected_text_parts)
def append_text(self, text: str) -> None:
"""追加文本内容(仅在需要收集文本时调用)"""
if text:
self._collected_text_parts.append(text)
def update_provider_info(
self,
@@ -145,6 +179,19 @@ class StreamContext:
self.status_code = status_code
self.error_message = error_message
def record_first_byte_time(self, start_time: float) -> None:
"""
记录首字时间 (TTFB - Time To First Byte)
应在第一次向客户端发送数据时调用。
如果已记录过,则不会覆盖(避免重试时重复记录)。
Args:
start_time: 请求开始时间 (time.time())
"""
if self.first_byte_time_ms is None:
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
def is_success(self) -> bool:
"""检查请求是否成功"""
return self.status_code < 400
@@ -171,10 +218,22 @@ class StreamContext:
获取日志摘要
用于请求完成/失败时的日志输出。
包含首字时间 (TTFB) 和总响应时间,分两行显示。
"""
status = "OK" if self.is_success() else "FAIL"
return (
# 第一行:基本信息 + 首字时间
line1 = (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
f"{self.provider_name or 'unknown'}"
)
if self.first_byte_time_ms is not None:
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
# 第二行:总响应时间 + tokens
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)
return f"{line1}\n{line2}"

View File

@@ -6,14 +6,22 @@
2. 响应流生成
3. 预读和嵌套错误检测
4. 客户端断开检测
5. 流式平滑输出
"""
import asyncio
import codecs
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.content_extractors import (
ContentExtractor,
get_extractor,
get_extractor_formats,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
@@ -23,11 +31,20 @@ from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamSmoothingConfig:
"""流式平滑输出配置"""
enabled: bool = False
chunk_size: int = 20
delay_ms: int = 8
class StreamProcessor:
"""
流式响应处理器
负责处理 SSE 流的解析、错误检测响应生成。
负责处理 SSE 流的解析、错误检测响应生成和平滑输出
从 ChatHandlerBase 中提取,使其职责更加单一。
"""
@@ -36,6 +53,9 @@ class StreamProcessor:
request_id: str,
default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
smoothing_config: Optional[StreamSmoothingConfig] = None,
):
"""
初始化流处理器
@@ -44,10 +64,17 @@ class StreamProcessor:
request_id: 请求 ID用于日志
default_parser: 默认响应解析器
on_streaming_start: 流开始时的回调(用于更新状态)
collect_text: 是否收集文本内容
smoothing_config: 流式平滑输出配置
"""
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
self.smoothing_config = smoothing_config or StreamSmoothingConfig()
# 内容提取器缓存
self._extractors: dict[str, ContentExtractor] = {}
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
@@ -112,18 +139,26 @@ class StreamProcessor:
)
# 提取文本
text = parser.extract_text_content(data)
if text:
ctx.collected_text += text
if self.collect_text:
text = parser.extract_text_content(data)
if text:
ctx.append_text(text)
# 检查完成
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
# 检查 OpenAI 格式的 finish_reason
choices = data.get("choices", [])
if choices and isinstance(choices, list) and len(choices) > 0:
finish_reason = choices[0].get("finish_reason")
if finish_reason is not None:
ctx.has_completion = True
async def prefetch_and_check_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -136,97 +171,126 @@ class StreamProcessor:
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
Returns:
预读的列表
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
if line_count >= max_prefetch_lines:
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = line
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
raise
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def create_response_stream(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None,
prefetched_chunks: Optional[list] = None,
*,
start_time: Optional[float] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
统一的流生成器,支持预读数据和不带预读数据两种情况
从字节流中解析 SSE 数据并转发,支持预读数据。
Args:
ctx: 流式上下文
line_iterator: 迭代器
byte_iterator: 字节流迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_lines: 预读的列表(可选)
prefetched_chunks: 预读的字节块列表(可选)
start_time: 请求开始时间,用于计算 TTFB可选
Yields:
编码后的响应数据块
@@ -234,25 +298,82 @@ class StreamProcessor:
try:
sse_parser = SSEEventParser()
streaming_started = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 处理预读数据
if prefetched_lines:
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for line in prefetched_lines:
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 把原始数据转发给客户端
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的流数据
async for line in line_iterator:
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 原始数据透传
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的缓冲区数据(如果有未完成的行)
if buffer:
try:
# 使用 final=True 处理最后的不完整字符
line = decoder.decode(buffer, True)
self._process_line(ctx, sse_parser, line)
except Exception as e:
logger.warning(
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
f"bytes={buffer[:50]!r}"
)
# 处理剩余事件
for event in sse_parser.flush():
@@ -268,7 +389,7 @@ class StreamProcessor:
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> list[bytes]:
) -> None:
"""
处理单行数据
@@ -276,26 +397,17 @@ class StreamProcessor:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
Returns:
要发送的数据块列表
"""
result: list[bytes] = []
normalized_line = line.rstrip("\r")
# SSEEventParser 以"去掉换行符"的单行文本作为输入;这里统一剔除 CR/LF
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
if normalized_line != "":
ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
async def create_monitored_stream(
self,
@@ -317,22 +429,201 @@ class StreamProcessor:
响应数据块
"""
try:
async for chunk in stream_generator:
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
# 使用后台任务检查断连,完全不阻塞流式传输
disconnected = False
async def check_disconnect_background() -> None:
nonlocal disconnected
while not disconnected and not ctx.has_completion:
await asyncio.sleep(0.5)
if await is_disconnected():
disconnected = True
break
# 启动后台检查任务
check_task = asyncio.create_task(check_disconnect_background())
try:
async for chunk in stream_generator:
if disconnected:
# 如果响应已完成,客户端断开不算失败
if ctx.has_completion:
logger.info(
f"ID:{self.request_id} | Client disconnected after completion"
)
else:
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499
ctx.error_message = "client_disconnected"
break
yield chunk
finally:
check_task.cancel()
try:
await check_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
# 如果响应已完成,不标记为失败
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def create_smoothed_stream(
self,
stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
"""
创建平滑输出的流生成器
如果启用了平滑输出,将大 chunk 拆分成小块并添加微小延迟。
否则直接透传原始流。
Args:
stream_generator: 原始流生成器
Yields:
平滑处理后的响应数据块
"""
if not self.smoothing_config.enabled:
# 未启用平滑输出,直接透传
async for chunk in stream_generator:
yield chunk
return
# 启用平滑输出
buffer = b""
is_first_content = True
async for chunk in stream_generator:
buffer += chunk
# 按双换行分割 SSE 事件(标准 SSE 格式)
while b"\n\n" in buffer:
event_block, buffer = buffer.split(b"\n\n", 1)
event_str = event_block.decode("utf-8", errors="replace")
# 解析事件块
lines = event_str.strip().split("\n")
data_str = None
event_type = ""
for line in lines:
line = line.rstrip("\r")
if line.startswith("event: "):
event_type = line[7:].strip()
elif line.startswith("data: "):
data_str = line[6:]
# 没有 data 行,直接透传
if data_str is None:
yield event_block + b"\n\n"
continue
# [DONE] 直接透传
if data_str.strip() == "[DONE]":
yield event_block + b"\n\n"
continue
# 尝试解析 JSON
try:
data = json.loads(data_str)
except json.JSONDecodeError:
yield event_block + b"\n\n"
continue
# 检测格式并提取内容
content, extractor = self._detect_format_and_extract(data)
# 只有内容长度大于 1 才需要平滑处理
if content and len(content) > 1 and extractor:
# 获取配置的延迟
delay_seconds = self._calculate_delay()
# 拆分内容
content_chunks = self._split_content(content)
for i, sub_content in enumerate(content_chunks):
is_first = is_first_content and i == 0
# 使用提取器创建新 chunk
sse_chunk = extractor.create_chunk(
data,
sub_content,
event_type=event_type,
is_first=is_first,
)
yield sse_chunk
# 除了最后一个块,其他块之间加延迟
if i < len(content_chunks) - 1:
await asyncio.sleep(delay_seconds)
is_first_content = False
else:
# 不需要拆分,直接透传
yield event_block + b"\n\n"
if content:
is_first_content = False
# 处理剩余数据
if buffer:
yield buffer
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
"""获取或创建格式对应的提取器(带缓存)"""
if format_name not in self._extractors:
extractor = get_extractor(format_name)
if extractor:
self._extractors[format_name] = extractor
return self._extractors.get(format_name)
def _detect_format_and_extract(
self, data: dict
) -> tuple[Optional[str], Optional[ContentExtractor]]:
"""
检测数据格式并提取内容
依次尝试各格式的提取器,返回第一个成功提取内容的结果。
Returns:
(content, extractor): 提取的内容和对应的提取器
"""
for format_name in get_extractor_formats():
extractor = self._get_extractor(format_name)
if extractor:
content = extractor.extract_content(data)
if content is not None:
return content, extractor
return None, None
def _calculate_delay(self) -> float:
"""获取配置的延迟(秒)"""
return self.smoothing_config.delay_ms / 1000.0
def _split_content(self, content: str) -> list[str]:
"""
按块拆分文本
"""
chunk_size = self.smoothing_config.chunk_size
text_length = len(content)
if text_length <= chunk_size:
return [content]
# 按块拆分
chunks = []
for i in range(0, text_length, chunk_size):
chunks.append(content[i : i + chunk_size])
return chunks
async def _cleanup(
self,
response_ctx: Any,
@@ -347,3 +638,128 @@ class StreamProcessor:
await http_client.aclose()
except Exception:
pass
async def create_smoothed_stream(
stream_generator: AsyncGenerator[bytes, None],
chunk_size: int = 20,
delay_ms: int = 8,
) -> AsyncGenerator[bytes, None]:
"""
独立的平滑流生成函数
供 CLI handler 等场景使用,无需创建完整的 StreamProcessor 实例。
Args:
stream_generator: 原始流生成器
chunk_size: 每块字符数
delay_ms: 每块之间的延迟毫秒数
Yields:
平滑处理后的响应数据块
"""
processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
async for chunk in processor.smooth(stream_generator):
yield chunk
class _LightweightSmoother:
"""
轻量级平滑处理器
只包含平滑输出所需的最小逻辑,不依赖 StreamProcessor 的其他功能。
"""
def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
self.chunk_size = chunk_size
self.delay_ms = delay_ms
self._extractors: dict[str, ContentExtractor] = {}
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
if format_name not in self._extractors:
extractor = get_extractor(format_name)
if extractor:
self._extractors[format_name] = extractor
return self._extractors.get(format_name)
def _detect_format_and_extract(
self, data: dict
) -> tuple[Optional[str], Optional[ContentExtractor]]:
for format_name in get_extractor_formats():
extractor = self._get_extractor(format_name)
if extractor:
content = extractor.extract_content(data)
if content is not None:
return content, extractor
return None, None
def _calculate_delay(self) -> float:
return self.delay_ms / 1000.0
def _split_content(self, content: str) -> list[str]:
text_length = len(content)
if text_length <= self.chunk_size:
return [content]
return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]
async def smooth(
self, stream_generator: AsyncGenerator[bytes, None]
) -> AsyncGenerator[bytes, None]:
buffer = b""
is_first_content = True
async for chunk in stream_generator:
buffer += chunk
while b"\n\n" in buffer:
event_block, buffer = buffer.split(b"\n\n", 1)
event_str = event_block.decode("utf-8", errors="replace")
lines = event_str.strip().split("\n")
data_str = None
event_type = ""
for line in lines:
line = line.rstrip("\r")
if line.startswith("event: "):
event_type = line[7:].strip()
elif line.startswith("data: "):
data_str = line[6:]
if data_str is None:
yield event_block + b"\n\n"
continue
if data_str.strip() == "[DONE]":
yield event_block + b"\n\n"
continue
try:
data = json.loads(data_str)
except json.JSONDecodeError:
yield event_block + b"\n\n"
continue
content, extractor = self._detect_format_and_extract(data)
if content and len(content) > 1 and extractor:
delay_seconds = self._calculate_delay()
content_chunks = self._split_content(content)
for i, sub_content in enumerate(content_chunks):
is_first = is_first_content and i == 0
sse_chunk = extractor.create_chunk(
data, sub_content, event_type=event_type, is_first=is_first
)
yield sse_chunk
if i < len(content_chunks) - 1:
await asyncio.sleep(delay_seconds)
is_first_content = False
else:
yield event_block + b"\n\n"
if content:
is_first_content = False
if buffer:
yield buffer

View File

@@ -8,6 +8,7 @@
"""
import asyncio
import time
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
response_time_ms: int,
start_time: float,
) -> None:
"""
记录流式统计信息
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
ctx: 流式上下文
original_headers: 原始请求头
original_request_body: 原始请求体
response_time_ms: 响应时间(毫秒)
start_time: 请求开始时间 (time.time())
"""
bg_db = None
try:
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
# 注意不要把统计延迟stream_stats_delay算进响应时间里
response_time_ms = int((time.time() - start_time) * 1000)
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
if not ctx.provider_name:
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,

View File

@@ -2,7 +2,7 @@
Handler 基础工具函数
"""
from typing import Any, Dict
from typing import Any, Dict, Optional
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
@@ -22,10 +22,34 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
Returns:
缓存创建 tokens 总数
"""
# 优先使用新格式
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
total = int(cache_5m) + int(cache_1h)
# 检查新格式字段是否存在(而非值是否为 0
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
has_new_format = (
"claude_cache_creation_5_m_tokens" in usage
or "claude_cache_creation_1_h_tokens" in usage
)
# 如果新格式不存在total == 0回退到旧格式
return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0))
if has_new_format:
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
return int(cache_5m) + int(cache_1h)
# 回退到旧格式
return int(usage.get("cache_creation_input_tokens", 0))
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
构建 SSEtext/event-stream推荐响应头用于减少代理缓冲带来的卡顿/成段输出。
说明:
- Cache-Control: no-transform 可避免部分代理对流做压缩/改写导致缓冲
- X-Accel-Buffering: no 可显式提示 Nginx 关闭缓冲(即使全局已关闭也无害)
"""
headers: Dict[str, str] = {
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
}
if extra_headers:
headers.update(extra_headers)
return headers

View File

@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return result
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -111,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 处理消息增量(包含最终 usage
elif event_type == "message_delta":

View File

@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
TODO: 如果需要,实现响应规范化逻辑
"""
# 可选:使用 response_normalizer 进行规范化
# if (
# self.response_normalizer
# and self.response_normalizer.should_normalize(response)
# ):
# return self.response_normalizer.normalize_gemini_response(
# response_data=response,
# request_id=self.request_id,
# strict=False,
# )
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
parts = content.get("parts", [])
for part in parts:
if "text" in part:
ctx.collected_text += part["text"]
ctx.append_text(part["text"])
# 检查结束原因
finish_reason = candidate.get("finishReason")

View File

@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_openai_response(
response_data=response,
request_id=self.request_id,
strict=False,
)
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
delta = data.get("delta")
if isinstance(delta, str):
ctx.collected_text += delta
ctx.append_text(delta)
elif isinstance(delta, dict) and "text" in delta:
ctx.collected_text += delta["text"]
ctx.append_text(delta["text"])
# 处理完成事件
elif event_type == "response.completed":
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if content_item.get("type") == "output_text":
text = content_item.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 备用:从顶层 usage 提取
usage_obj = data.get("usage")

View File

@@ -210,9 +210,9 @@ class PublicModelsAdapter(PublicApiAdapter):
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
description=global_model.config.get("description") if global_model and global_model.config else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
@@ -274,7 +274,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
Model.provider_model_name.ilike(f"%{self.query}%")
| GlobalModel.name.ilike(f"%{self.query}%")
| GlobalModel.display_name.ilike(f"%{self.query}%")
| GlobalModel.description.ilike(f"%{self.query}%")
)
query_stmt = query_stmt.filter(search_filter)
if self.provider_id is not None:
@@ -293,9 +292,9 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
description=global_model.config.get("description") if global_model and global_model.config else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
@@ -499,7 +498,6 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
or_(
GlobalModel.name.ilike(search_term),
GlobalModel.display_name.ilike(search_term),
GlobalModel.description.ilike(search_term),
)
)
@@ -517,21 +515,11 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
id=gm.id,
name=gm.name,
display_name=gm.display_name,
description=gm.description,
icon_url=gm.icon_url,
is_active=gm.is_active,
default_price_per_request=gm.default_price_per_request,
default_tiered_pricing=gm.default_tiered_pricing,
default_supports_vision=gm.default_supports_vision or False,
default_supports_function_calling=gm.default_supports_function_calling or False,
default_supports_streaming=(
gm.default_supports_streaming
if gm.default_supports_streaming is not None
else True
),
default_supports_extended_thinking=gm.default_supports_extended_thinking
or False,
supported_capabilities=gm.supported_capabilities,
config=gm.config,
)
)

View File

@@ -251,8 +251,8 @@ def _build_gemini_list_response(
"version": "001",
"displayName": m.display_name,
"description": m.description or f"Model {m.id}",
"inputTokenLimit": 128000,
"outputTokenLimit": 8192,
"inputTokenLimit": m.context_limit if m.context_limit is not None else 128000,
"outputTokenLimit": m.output_limit if m.output_limit is not None else 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"],
"temperature": 1.0,
"maxTemperature": 2.0,
@@ -297,8 +297,8 @@ def _build_gemini_model_response(model_info: ModelInfo) -> dict:
"version": "001",
"displayName": model_info.display_name,
"description": model_info.description or f"Model {model_info.id}",
"inputTokenLimit": 128000,
"outputTokenLimit": 8192,
"inputTokenLimit": model_info.context_limit if model_info.context_limit is not None else 128000,
"outputTokenLimit": model_info.output_limit if model_info.output_limit is not None else 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"],
"temperature": 1.0,
"maxTemperature": 2.0,

View File

@@ -5,12 +5,55 @@
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from urllib.parse import quote, urlparse
import httpx
from src.core.logger import logger
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
"""
根据代理配置构建完整的代理 URL
Args:
proxy_config: 代理配置字典,包含 url, username, password, enabled
Returns:
完整的代理 URL如 socks5://user:pass@host:port
如果 enabled=False 或无配置,返回 None
"""
if not proxy_config:
return None
# 检查 enabled 字段,默认为 True兼容旧数据
if not proxy_config.get("enabled", True):
return None
proxy_url = proxy_config.get("url")
if not proxy_url:
return None
username = proxy_config.get("username")
password = proxy_config.get("password")
# 只要有用户名就添加认证信息(密码可以为空)
if username:
parsed = urlparse(proxy_url)
# URL 编码用户名和密码,处理特殊字符(如 @, :, /
encoded_username = quote(username, safe="")
encoded_password = quote(password, safe="") if password else ""
# 重新构建带认证的代理 URL
if encoded_password:
auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}"
else:
auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}"
if parsed.path:
auth_proxy += parsed.path
return auth_proxy
return proxy_url
class HTTPClientPool:
"""
@@ -121,6 +164,44 @@ class HTTPClientPool:
finally:
await client.aclose()
@classmethod
def create_client_with_proxy(
cls,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: Optional[httpx.Timeout] = None,
**kwargs: Any,
) -> httpx.AsyncClient:
"""
创建带代理配置的HTTP客户端
Args:
proxy_config: 代理配置字典,包含 url, username, password
timeout: 超时配置
**kwargs: 其他 httpx.AsyncClient 配置参数
Returns:
配置好的 httpx.AsyncClient 实例
"""
config: Dict[str, Any] = {
"http2": False,
"verify": True,
"follow_redirects": True,
}
if timeout:
config["timeout"] = timeout
else:
config["timeout"] = httpx.Timeout(10.0, read=300.0)
# 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
config.update(kwargs)
return httpx.AsyncClient(**config)
# 便捷访问函数
def get_http_client() -> httpx.AsyncClient:

View File

@@ -120,7 +120,7 @@ class RedisClientManager:
if self._circuit_open_until and time.time() < self._circuit_open_until:
remaining = self._circuit_open_until - time.time()
logger.warning(
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
remaining,
self._last_error,
)
@@ -200,7 +200,7 @@ class RedisClientManager:
if self._consecutive_failures >= self._circuit_threshold:
self._circuit_open_until = time.time() + self._circuit_reset_seconds
logger.warning(
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
self._consecutive_failures,
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
if _redis_manager is None:
_redis_manager = RedisClientManager()
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
if _redis_manager.get_client() is None:
await _redis_manager.initialize(require_redis=require_redis)
return _redis_manager.get_client()

View File

@@ -41,8 +41,8 @@ class CacheSize:
class ConcurrencyDefaults:
"""并发控制默认值"""
# 自适应并发初始限制(保守值
INITIAL_LIMIT = 3
# 自适应并发初始限制(宽松起步,遇到 429 再降低
INITIAL_LIMIT = 50
# 429错误后的冷却时间分钟- 在此期间不会增加并发限制
COOLDOWN_AFTER_429_MINUTES = 5
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
MIN_SAMPLES_FOR_DECISION = 5
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 1
INCREASE_STEP = 2
# 缩容乘数 - 遇到 429 时的缩容比例
DECREASE_MULTIPLIER = 0.7
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
# 0.85 表示降到触发 429 时并发数的 85%
DECREASE_MULTIPLIER = 0.85
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 100
MAX_CONCURRENT_LIMIT = 200
# 最小并发限制下限
MIN_CONCURRENT_LIMIT = 1
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
PROBE_INCREASE_MIN_REQUESTS = 10
# === 缓存用户预留比例 ===
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
# 0.1 表示缓存用户预留 10%,新用户可用 90%
CACHE_RESERVATION_RATIO = 0.1
class CircuitBreakerDefaults:
"""熔断器配置默认值(滑动窗口 + 半开状态模式)

View File

@@ -105,6 +105,13 @@ class Config:
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
# 可信代理配置
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1即信任最近一层代理
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
# 异常处理配置
# 设置为 True 时ProxyException 会传播到路由层以便记录 provider_request_headers
# 设置为 False 时,使用全局异常处理器统一处理
@@ -122,9 +129,9 @@ class Config:
# 并发控制配置
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL防止死锁
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 10%,新用户可用 90%
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))

View File

@@ -46,6 +46,11 @@ class BatchCommitter:
def mark_dirty(self, session: Session):
"""标记 Session 有待提交的更改"""
# 请求级事务由中间件统一 commit/rollback避免后台任务在请求中途误提交。
if session is None:
return
if session.info.get("managed_by_middleware"):
return
self._pending_sessions.add(session)
async def _batch_commit_loop(self):

View File

@@ -1,168 +0,0 @@
"""
统一的请求上下文
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
这确保了数据在各层之间传递时不会丢失。
使用方式:
1. Pipeline 层创建 RequestContext
2. 各层通过 context 访问和更新信息
3. Adapter 层使用 context 记录 Usage
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class RequestContext:
"""
请求上下文 - 贯穿整个请求生命周期
设计原则:
1. 在请求开始时创建,包含所有已知信息
2. 在请求执行过程中逐步填充 Provider 信息
3. 在请求结束时用于记录 Usage
"""
# ==================== 请求标识 ====================
request_id: str
# ==================== 认证信息 ====================
user: Any # User model
api_key: Any # ApiKey model
db: Any # Database session
# ==================== 请求信息 ====================
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
model: str # 用户请求的模型名
is_stream: bool = False
# ==================== 原始请求 ====================
original_headers: Dict[str, str] = field(default_factory=dict)
original_body: Dict[str, Any] = field(default_factory=dict)
# ==================== 客户端信息 ====================
client_ip: str = "unknown"
user_agent: str = ""
# ==================== 计时 ====================
start_time: float = field(default_factory=time.time)
# ==================== Provider 信息(请求执行后填充)====================
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
provider_api_key_id: Optional[str] = None
# ==================== 模型映射信息 ====================
resolved_model: Optional[str] = None # 映射后的模型名
original_model: Optional[str] = None # 原始模型名(用于价格计算)
# ==================== 请求/响应头 ====================
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_response_headers: Dict[str, str] = field(default_factory=dict)
# ==================== 追踪信息 ====================
attempt_id: Optional[str] = None
# ==================== 能力需求 ====================
capability_requirements: Dict[str, bool] = field(default_factory=dict)
# 运行时计算的能力需求,来源于:
# 1. 用户 model_capability_settings
# 2. 用户 ApiKey.force_capabilities
# 3. 请求头 X-Require-Capability
# 4. 失败重试时动态添加
@classmethod
def create(
cls,
*,
db: Any,
user: Any,
api_key: Any,
api_format: str,
model: str,
is_stream: bool = False,
original_headers: Optional[Dict[str, str]] = None,
original_body: Optional[Dict[str, Any]] = None,
client_ip: str = "unknown",
user_agent: str = "",
request_id: Optional[str] = None,
) -> "RequestContext":
"""创建请求上下文"""
return cls(
request_id=request_id or str(uuid.uuid4()),
db=db,
user=user,
api_key=api_key,
api_format=api_format,
model=model,
is_stream=is_stream,
original_headers=original_headers or {},
original_body=original_body or {},
client_ip=client_ip,
user_agent=user_agent,
original_model=model, # 初始时原始模型等于请求模型
)
def update_provider_info(
self,
*,
provider_name: str,
provider_id: str,
endpoint_id: str,
provider_api_key_id: str,
resolved_model: Optional[str] = None,
) -> None:
"""更新 Provider 信息(请求执行后调用)"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.provider_api_key_id = provider_api_key_id
if resolved_model:
self.resolved_model = resolved_model
def update_headers(
self,
*,
request_headers: Optional[Dict[str, str]] = None,
response_headers: Optional[Dict[str, str]] = None,
) -> None:
"""更新请求/响应头"""
if request_headers:
self.provider_request_headers = request_headers
if response_headers:
self.provider_response_headers = response_headers
@property
def elapsed_ms(self) -> int:
"""计算已经过的时间(毫秒)"""
return int((time.time() - self.start_time) * 1000)
@property
def effective_model(self) -> str:
"""获取有效的模型名(映射后优先)"""
return self.resolved_model or self.model
@property
def billing_model(self) -> str:
"""获取计费模型名(原始模型优先)"""
return self.original_model or self.model
def to_metadata_dict(self) -> Dict[str, Any]:
"""转换为元数据字典(用于 Usage 记录)"""
return {
"api_format": self.api_format,
"provider": self.provider_name or "unknown",
"model": self.effective_model,
"original_model": self.billing_model,
"provider_id": self.provider_id,
"provider_endpoint_id": self.endpoint_id,
"provider_api_key_id": self.provider_api_key_id,
"provider_request_headers": self.provider_request_headers,
"provider_response_headers": self.provider_response_headers,
"attempt_id": self.attempt_id,
}

View File

@@ -10,8 +10,8 @@ class APIFormat(Enum):
"""API格式枚举 - 决定请求/响应的处理方式"""
CLAUDE = "CLAUDE" # Claude API 格式
OPENAI = "OPENAI" # OpenAI API 格式
CLAUDE_CLI = "CLAUDE_CLI" # Claude CLI API 格式(使用 authorization: Bearer
OPENAI = "OPENAI" # OpenAI API 格式
OPENAI_CLI = "OPENAI_CLI" # OpenAI CLI/Responses API 格式(用于 Claude Code 等客户端)
GEMINI = "GEMINI" # Google Gemini API 格式
GEMINI_CLI = "GEMINI_CLI" # Gemini CLI API 格式

View File

@@ -188,12 +188,16 @@ class ProviderNotAvailableException(ProviderException):
message: str,
provider_name: Optional[str] = None,
request_metadata: Optional[Any] = None,
upstream_status: Optional[int] = None,
upstream_response: Optional[str] = None,
):
super().__init__(
message=message,
provider_name=provider_name,
request_metadata=request_metadata,
)
self.upstream_status = upstream_status
self.upstream_response = upstream_response
class ProviderTimeoutException(ProviderException):
@@ -442,6 +446,36 @@ class EmbeddedErrorException(ProviderException):
self.error_status = error_status
class ProviderCompatibilityException(ProviderException):
"""Provider 兼容性错误异常 - 应该触发故障转移
用于处理因 Provider 不支持某些参数或功能导致的错误。
这类错误不是用户请求本身的问题,换一个 Provider 可能就能成功,应该触发故障转移。
常见场景:
- Unsupported parameter不支持的参数
- Unsupported model不支持的模型
- Unsupported feature不支持的功能
"""
def __init__(
self,
message: str,
provider_name: Optional[str] = None,
status_code: int = 400,
upstream_error: Optional[str] = None,
request_metadata: Optional[Any] = None,
):
self.upstream_error = upstream_error
super().__init__(
message=message,
provider_name=provider_name,
request_metadata=request_metadata,
)
# 覆盖状态码为 400保持与上游一致
self.status_code = status_code
class UpstreamClientException(ProxyException):
"""上游返回的客户端错误异常 - HTTP 4xx 错误,不应该重试

View File

@@ -9,7 +9,7 @@
输出策略:
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
- 文件: 始终保存 DEBUG 级别保留30天每日轮转
- 文件: 始终保存 DEBUG 级别保留30天按大小轮转 (100MB)
使用方式:
from src.core.logger import logger
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
if IS_DOCKER:
# 生产环境:禁用 backtrace 和 diagnose减少日志噪音
logger.add(
sys.stdout,
format=CONSOLE_FORMAT_PROD,
level=LOG_LEVEL,
filter=_log_filter, # type: ignore[arg-type]
colorize=False,
backtrace=False,
diagnose=False,
)
else:
logger.add(
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
log_dir = PROJECT_ROOT / "logs"
log_dir.mkdir(exist_ok=True)
# 文件日志通用配置
file_log_config = {
"format": FILE_FORMAT,
"filter": _log_filter,
"rotation": "100 MB",
"retention": "30 days",
"compression": "gz",
"enqueue": True,
"encoding": "utf-8",
"catch": True,
}
# 生产环境禁用详细堆栈
if IS_DOCKER:
file_log_config["backtrace"] = False
file_log_config["diagnose"] = False
# 主日志文件 - 所有级别
logger.add(
log_dir / "app.log",
format=FILE_FORMAT,
level="DEBUG",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**file_log_config, # type: ignore[arg-type]
)
# 错误日志文件 - 仅 ERROR 及以上
error_log_config = file_log_config.copy()
error_log_config["rotation"] = "50 MB"
logger.add(
log_dir / "error.log",
format=FILE_FORMAT,
level="ERROR",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**error_log_config, # type: ignore[arg-type]
)
# ============================================================================

View File

@@ -5,6 +5,7 @@
import time
from typing import AsyncGenerator, Generator, Optional
from starlette.requests import Request
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
@@ -150,9 +151,22 @@ def _log_pool_capacity():
theoretical = config.db_pool_size + config.db_max_overflow
workers = max(1, config.worker_processes)
total_estimated = theoretical * workers
logger.info("数据库连接池配置")
if total_estimated > config.db_pool_warn_threshold:
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
safe_limit = config.pg_max_connections - config.pg_reserved_connections
logger.info(
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
config.db_pool_size,
config.db_max_overflow,
workers,
total_estimated,
safe_limit,
)
if total_estimated > safe_limit:
logger.warning(
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved)"
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
total_estimated,
safe_limit,
)
def _ensure_async_engine() -> AsyncEngine:
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
# 创建异步引擎
_async_engine = create_async_engine(
ASYNC_DATABASE_URL,
poolclass=QueuePool, # 使用队列连接池
# AsyncEngine 不能使用 QueuePool默认使用 AsyncAdaptedQueuePool
pool_size=config.db_pool_size,
max_overflow=config.db_max_overflow,
pool_timeout=config.db_pool_timeout,
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""获取异步数据库会话"""
"""获取异步数据库会话
.. deprecated::
此方法已废弃,项目统一使用同步 Session。
未来版本可能移除此方法。请使用 get_db() 代替。
"""
import warnings
warnings.warn(
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
DeprecationWarning,
stacklevel=2,
)
# 确保异步引擎已初始化
_ensure_async_engine()
@@ -220,16 +245,73 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
def get_db() -> Generator[Session, None, None]:
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
"""获取数据库会话
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback
这里只负责会话的创建和关闭,不自动提交
事务策略说明
============
本项目采用**混合事务管理**策略:
1. **LLM 请求路径**
- 由 PluginMiddleware 统一管理事务
- Service 层使用 db.flush() 使更改可见,但不提交
- 请求结束时由中间件统一 commit 或 rollback
- 例外UsageService.record_usage() 会显式 commit因为使用记录需要立即持久化
2. **管理后台 API**
- 路由层显式调用 db.commit()
- 提交后设置 request.state.tx_committed_by_route = True
- 中间件看到此标志后跳过 commit只负责 close
3. **后台任务/调度器**
- 使用独立 Session通过 create_session() 或 next(get_db())
- 自行管理事务生命周期
使用方式
========
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
路由层提交事务示例
==================
```python
@router.post("/example")
async def example(request: Request, db: Session = Depends(get_db)):
# ... 业务逻辑 ...
db.commit()
request.state.tx_committed_by_route = True # 告知中间件已提交
return {"message": "success"}
```
注意事项
========
- 本函数不自动提交事务
- 异常时会自动回滚
- 中间件管理模式下session 关闭由中间件负责
"""
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
if request is not None:
existing_db = getattr(getattr(request, "state", None), "db", None)
if isinstance(existing_db, Session):
yield existing_db
return
# 确保引擎已初始化
_ensure_engine()
db = _SessionLocal()
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state
# 并由中间件负责 commit/rollback/close这里不关闭避免流式响应提前释放会话
managed_by_middleware = bool(
request is not None
and hasattr(request, "state")
and getattr(request.state, "db_managed_by_middleware", False)
)
if managed_by_middleware:
request.state.db = db
db.info["managed_by_middleware"] = True
try:
yield db
# 不再自动 commit由业务代码显式管理事务
@@ -241,12 +323,13 @@ def get_db() -> Generator[Session, None, None]:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
raise
finally:
try:
db.close() # 确保连接返回池
except Exception as close_error:
# 记录关闭错误(如 IllegalStateChangeError
# 连接池会处理连接的回收
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
if not managed_by_middleware:
try:
db.close() # 确保连接返回池
except Exception as close_error:
# 记录关闭错误(如 IllegalStateChangeError
# 连接池会处理连接的回收
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def create_session() -> Session:
@@ -273,16 +356,17 @@ def get_db_url() -> str:
def init_db():
"""初始化数据库"""
"""初始化数据库
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
"""
logger.info("初始化数据库...")
# 确保引擎已创建
engine = _ensure_engine()
_ensure_engine()
# 创建所有表
Base.metadata.create_all(bind=engine)
# 数据库表已通过SQLAlchemy自动创建
# 数据库表结构由 Alembic 迁移管理
# 首次部署或更新后请运行: ./migrate.sh
db = _SessionLocal()
try:
@@ -335,7 +419,7 @@ def init_admin_user(db: Session):
admin.set_password(config.admin_password)
db.add(admin)
db.commit() # 刷新以获取ID但不提交
db.flush() # 分配ID但不提交事务(由外层 init_db 统一 commit
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
except Exception as e:

View File

@@ -3,15 +3,11 @@
采用模块化架构设计
"""
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from src.api.admin import router as admin_router
from src.api.announcements import router as announcement_router
@@ -39,20 +35,18 @@ async def initialize_providers():
"""从数据库初始化提供商(仅用于日志记录)"""
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.database import get_db
from src.database.database import create_session
from src.models.database import Provider
try:
# 创建数据库会话
db_gen = get_db()
db: Session = next(db_gen)
db: Session = create_session()
try:
# 从数据库加载所有活跃的提供商
providers = (
db.query(Provider)
.filter(Provider.is_active == True)
.filter(Provider.is_active.is_(True))
.order_by(Provider.provider_priority.asc())
.all()
)
@@ -75,7 +69,7 @@ async def initialize_providers():
finally:
db.close()
except Exception as e:
except Exception:
logger.exception("从数据库初始化提供商失败")
@@ -125,6 +119,7 @@ async def lifespan(app: FastAPI):
logger.info("初始化全局Redis客户端...")
from src.clients.redis_client import get_redis_client
redis_client = None
try:
redis_client = await get_redis_client(require_redis=config.require_redis)
if redis_client:
@@ -136,6 +131,7 @@ async def lifespan(app: FastAPI):
logger.exception("[ERROR] Redis连接失败应用启动中止")
raise
logger.warning(f"Redis连接失败但配置允许降级将继续使用内存模式: {e}")
redis_client = None
# 初始化并发管理器内部会使用Redis
logger.info("初始化并发管理器...")
@@ -300,33 +296,6 @@ app.include_router(dashboard_router) # 仪表盘端点
app.include_router(public_router) # 公开API端点用户可查看提供商和模型
app.include_router(monitoring_router) # 监控端点
# 静态文件服务(前端构建产物)
# 检查前端构建目录是否存在
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
if frontend_dist.exists():
# 挂载静态资源目录
app.mount("/assets", StaticFiles(directory=str(frontend_dist / "assets")), name="assets")
# SPA catch-all路由 - 必须放在最后
@app.get("/{full_path:path}")
async def serve_spa(request: Request, full_path: str):
"""
处理所有未匹配的GET请求返回index.html供前端路由处理
仅对非API路径生效
"""
# 如果是API路径不处理
if full_path.startswith("api/") or full_path.startswith("v1/"):
raise HTTPException(status_code=404, detail="Not Found")
# 返回index.html让前端路由处理
index_file = frontend_dist / "index.html"
if index_file.exists():
return FileResponse(str(index_file))
else:
raise HTTPException(status_code=404, detail="Frontend not built")
else:
logger.warning("前端构建目录不存在,前端路由将无法使用")
def main():

View File

@@ -1,38 +1,43 @@
"""
统一的插件中间件
统一的插件中间件(纯 ASGI 实现)
负责协调所有插件的调用
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware
以避免 Starlette 已知的流式响应兼容性问题。
"""
import hashlib
import time
from typing import Any, Awaitable, Callable, Optional
from typing import Optional
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from src.config import config
from src.core.logger import logger
from src.database import get_db
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
class PluginMiddleware(BaseHTTPMiddleware):
class PluginMiddleware:
"""
统一的插件调用中间件
统一的插件调用中间件(纯 ASGI 实现)
职责:
- 性能监控
- 限流控制 (可选)
- 数据库会话生命周期管理
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
- 纯 ASGI 可以直接透传流式响应,无额外开销
"""
def __init__(self, app: Any) -> None:
super().__init__(app)
def __init__(self, app: ASGIApp) -> None:
self.app = app
self.plugin_manager = get_plugin_manager()
# 从配置读取速率限制值
@@ -62,175 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
"/v1/completions",
]
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
) -> StarletteResponse:
"""处理请求并调用相应插件"""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI 入口点"""
if scope["type"] != "http":
# 非 HTTP 请求(如 WebSocket直接透传
await self.app(scope, receive, send)
return
# 构建 Request 对象以便复用现有逻辑
request = Request(scope, receive, send)
# 记录请求开始时间
start_time = time.time()
# 设置 request.state 属性
# 注意Starlette 的 Request 对象总是有 state 属性State 实例)
request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
request.state.db_managed_by_middleware = True
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
db_func = get_db
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
if get_db in request.app.dependency_overrides:
db_func = request.app.dependency_overrides[get_db]
logger.debug("Using overridden get_db from app.dependency_overrides")
# 1. 限流检查(在调用下游之前
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
await self._send_rate_limit_response(send, rate_limit_result)
return
# 创建数据库会话供需要的插件或后续处理使
db_gen = db_func()
db = None
response = None
exception_to_raise = None
# 2. 预处理插件调
await self._call_pre_request_plugins(request)
# 用于捕获响应状态码
response_status_code: int = 0
async def send_wrapper(message: Message) -> None:
nonlocal response_status_code
if message["type"] == "http.response.start":
response_status_code = message.get("status", 0)
await send(message)
# 3. 调用下游应用
exception_occurred: Optional[Exception] = None
try:
# 获取数据库会话
db = next(db_gen)
request.state.db = db
# 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
headers = rate_limit_result.headers or {}
raise HTTPException(
status_code=429,
detail=rate_limit_result.message or "Rate limit exceeded",
headers=headers,
)
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 处理请求
response = await call_next(request)
# 3. 提交关键数据库事务(在返回响应前)
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
try:
db.commit()
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
db.rollback()
# 返回 500 错误,因为数据可能不一致
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "database_error",
"message": "数据保存失败,请重试",
},
},
)
# 跳过后处理插件,直接返回错误响应
return response
# 4. 后处理插件调用(监控等,非关键操作)
# 这些操作失败不应影响用户响应
await self._call_post_request_plugins(request, response, start_time)
# 注意不在此处添加限流响应头因为在BaseHTTPMiddleware中
# 响应返回后修改headers会导致Content-Length不匹配错误
# 限流响应头已在返回429错误时正确包含见上面的HTTPException
except RuntimeError as e:
if str(e) == "No response returned.":
if db:
db.rollback()
logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time)
if db:
try:
db.commit()
except Exception:
pass
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "internal_error",
"message": "Internal server error: downstream handler returned no response.",
},
},
)
else:
exception_to_raise = e
await self.app(scope, receive, send_wrapper)
except Exception as e:
# 回滚数据库事务
if db:
db.rollback()
exception_occurred = e
# 错误处理插件调用
await self._call_error_plugins(request, e, start_time)
raise
finally:
# 4. 数据库会话清理(无论成功与否)
await self._cleanup_db_session(request, exception_occurred)
# 尝试提交错误日志
if db:
# 5. 后处理插件调用(仅在成功时)
if not exception_occurred and response_status_code > 0:
await self._call_post_request_plugins(request, response_status_code, start_time)
async def _send_rate_limit_response(
self, send: Send, result: RateLimitResult
) -> None:
"""发送 429 限流响应"""
import json
body = json.dumps({
"type": "error",
"error": {
"type": "rate_limit_error",
"message": result.message or "Rate limit exceeded",
},
}).encode("utf-8")
headers = [(b"content-type", b"application/json")]
if result.headers:
for key, value in result.headers.items():
headers.append((key.lower().encode(), str(value).encode()))
await send({
"type": "http.response.start",
"status": 429,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": body,
})
async def _cleanup_db_session(
self, request: Request, exception: Optional[Exception]
) -> None:
"""清理数据库会话
事务策略:
- 如果 request.state.tx_committed_by_route 为 True说明路由已自行提交中间件只负责 close
- 否则由中间件统一 commit/rollback
这避免了双重提交的问题,同时保持向后兼容。
"""
from sqlalchemy.orm import Session
db = getattr(request.state, "db", None)
if not isinstance(db, Session):
return
# 检查是否由路由层已经提交
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
try:
if exception is not None:
# 发生异常,回滚事务(无论谁负责提交)
try:
db.rollback()
except Exception as rollback_error:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
elif not tx_committed_by_route:
# 正常完成且路由未自行提交,由中间件提交事务
try:
db.commit()
except:
pass
exception_to_raise = e
finally:
# 确保数据库会话被正确关闭
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
if db is not None:
try:
# 检查会话是否可以安全地进行回滚
# 只有当没有进行中的事务操作时才尝试回滚
if db.is_active and not db.get_transaction().is_active:
# 事务不在活跃状态,可以安全回滚
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
db.rollback()
except Exception:
pass
elif db.is_active:
# 事务在活跃状态,尝试回滚
try:
db.rollback()
except Exception as rollback_error:
# 回滚失败(可能是 commit 正在进行中),忽略错误
logger.debug(f"Rollback skipped: {rollback_error}")
except Exception:
# 检查状态时出错,忽略
pass
# 通过触发生成器的 finally 块来关闭会话(标准模式)
# 这会调用 get_db() 的 finally 块,执行 db.close()
# 如果 tx_committed_by_route 为 True跳过 commit路由已提交
finally:
# 关闭会话,归还连接到连接池
try:
next(db_gen, None)
except StopIteration:
# 正常情况:生成器已耗尽
pass
except Exception as cleanup_error:
# 忽略 IllegalStateChangeError 等清理错误
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
logger.warning(f"Database cleanup warning: {cleanup_error}")
# 在 finally 块之后处理异常和响应
if exception_to_raise:
raise exception_to_raise
return response
db.close()
except Exception as close_error:
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端 IP 地址,支持代理头
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
仅当服务部署在可信代理(如 Nginx、CloudFlare后面时才安全。
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
"""
# 从配置获取可信代理层数(默认为 1即信任最近一层代理
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
# 优先从代理头获取真实 IP
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For 可能包含多个 IP取第一个
return forwarded_for.split(",")[0].strip()
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",")]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
real_ip = request.headers.get("x-real-ip")
if real_ip:
@@ -250,7 +239,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
return False
async def _get_rate_limit_key_and_config(
self, request: Request, db: Session
self, request: Request
) -> tuple[Optional[str], Optional[int]]:
"""
获取速率限制的key和配置
@@ -272,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
auth_header = request.headers.get("authorization", "")
api_key = request.headers.get("x-api-key", "")
if auth_header.startswith("Bearer "):
if auth_header.lower().startswith("bearer "):
api_key = auth_header[7:]
if api_key:
# 使用 API Key 的哈希作为限制 key避免日志泄露完整 key
import hashlib
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
key = f"llm_api_key:{key_hash}"
request.state.rate_limit_key_type = "api_key"
@@ -318,14 +305,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
# 如果没有限流插件,允许通过
return None
# 获取数据库会话
db = getattr(request.state, "db", None)
if not db:
logger.warning("速率限制检查:无法获取数据库会话")
return None
# 获取速率限制的key和配置从数据库
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
# 获取速率限制的 key 和配置
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
if not key:
# 不需要限流的端点(如未分类路径),静默跳过
return None
@@ -336,7 +317,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
key=key,
endpoint=request.url.path,
method=request.method,
rate_limit=rate_limit_value, # 传入数据库配置的限制值
rate_limit=rate_limit_value, # 传入配置的限制值
)
# 类型检查确保返回的是RateLimitResult类型
if isinstance(result, RateLimitResult):
@@ -349,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
)
else:
# 限流触发,记录日志
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
logger.warning(
"速率限制触发: {}",
getattr(request.state, "rate_limit_key_type", "unknown"),
)
return result
return None
except Exception as e:
@@ -362,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
pass
async def _call_post_request_plugins(
self, request: Request, response: StarletteResponse, start_time: float
self, request: Request, status_code: int, start_time: float
) -> None:
"""调用请求后的插件"""
@@ -375,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
monitor_labels = {
"method": request.method,
"endpoint": request.url.path,
"status": str(response.status_code),
"status_class": f"{response.status_code // 100}xx",
"status": str(status_code),
"status_class": f"{status_code // 100}xx",
}
# 记录请求计数
@@ -398,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
self, request: Request, error: Exception, start_time: float
) -> None:
"""调用错误处理插件"""
from fastapi import HTTPException
duration = time.time() - start_time
@@ -410,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
error=error,
context={
"endpoint": f"{request.method} {request.url.path}",
"request_id": request.state.request_id,
"request_id": getattr(request.state, "request_id", ""),
"duration": duration,
},
)

Some files were not shown because too many files have changed in this diff Show More