Mirror of https://github.com/fawney19/Aether.git (synced 2026-01-03 00:02:28 +08:00)
Compare commits
40 Commits
- 2fa64b98e3
- 75d7e89cbb
- d73a443484
- 15a9b88fc8
- 03eb7203ec
- e38cd6819b
- d44cfaddf6
- 65225710a8
- d7f5b16359
- 7185818724
- 868f3349e5
- d7384e69d9
- 1d5c378343
- 4e1aed9976
- e2e7996a54
- df9f9a9f4f
- 7553b0da80
- 8f30bf0bef
- 8c12174521
- 6aa1876955
- 7f07122aea
- c2ddc6bd3c
- af476ff21e
- 3bbc1c6b66
- c69a0a8506
- 1fae202bde
- b9a26c4550
- e42bd35d48
- f22a073fd9
- 5c7ad089d2
- 97425ac68f
- 912f6643e2
- 6c0373fda6
- 070121717d
- 85fafeacb8
- daf8b870f0
- 880fb61c66
- 7e792dabfc
- cd06169b2f
- 50ffd47546
10  .env.example
@@ -1,8 +1,16 @@
# ==================== Required configuration (before startup) ====================
# The following settings must be configured before the project starts

# Database password
# Database configuration
DB_HOST=localhost
DB_PORT=5432
DB_USER=postgres
DB_NAME=aether
DB_PASSWORD=your_secure_password_here

# Redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=your_redis_password_here

# JWT secret key (generate with python generate_keys.py)
39  .github/workflows/docker-publish.yml  (vendored)
@@ -15,6 +15,8 @@ env:
  REGISTRY: ghcr.io
  BASE_IMAGE_NAME: fawney19/aether-base
  APP_IMAGE_NAME: fawney19/aether
  # Files that affect base image - used for hash calculation
  BASE_FILES: "Dockerfile.base pyproject.toml frontend/package.json frontend/package-lock.json"

jobs:
  check-base-changes:
@@ -23,8 +25,13 @@ jobs:
      base_changed: ${{ steps.check.outputs.base_changed }}
    steps:
      - uses: actions/checkout@v4

      - name: Log in to Container Registry
        uses: docker/login-action@v3
        with:
          fetch-depth: 2
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Check if base image needs rebuild
        id: check
@@ -34,10 +41,26 @@ jobs:
            exit 0
          fi

          # Check if base-related files changed
          if git diff --name-only HEAD~1 HEAD | grep -qE '^(Dockerfile\.base|pyproject\.toml|frontend/package.*\.json)$'; then
          # Calculate current hash of base-related files
          CURRENT_HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
          echo "Current base files hash: $CURRENT_HASH"

          # Try to get hash label from remote image config
          # Pull the image config and extract labels
          REMOTE_HASH=""
          if docker pull ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest 2>/dev/null; then
            REMOTE_HASH=$(docker inspect ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest --format '{{ index .Config.Labels "org.opencontainers.image.base.hash" }}' 2>/dev/null) || true
          fi

          if [ -z "$REMOTE_HASH" ] || [ "$REMOTE_HASH" == "<no value>" ]; then
            # No remote image or no hash label, need to rebuild
            echo "No remote base image or hash label found, need rebuild"
            echo "base_changed=true" >> $GITHUB_OUTPUT
          elif [ "$CURRENT_HASH" != "$REMOTE_HASH" ]; then
            echo "Hash mismatch: remote=$REMOTE_HASH, current=$CURRENT_HASH"
            echo "base_changed=true" >> $GITHUB_OUTPUT
          else
            echo "Hash matches, no rebuild needed"
            echo "base_changed=false" >> $GITHUB_OUTPUT
          fi

@@ -61,6 +84,12 @@ jobs:
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Calculate base files hash
        id: hash
        run: |
          HASH=$(cat ${{ env.BASE_FILES }} 2>/dev/null | sha256sum | cut -d' ' -f1)
          echo "hash=$HASH" >> $GITHUB_OUTPUT

      - name: Extract metadata for base image
        id: meta
        uses: docker/metadata-action@v5
@@ -69,6 +98,8 @@ jobs:
          tags: |
            type=raw,value=latest
            type=sha,prefix=
          labels: |
            org.opencontainers.image.base.hash=${{ steps.hash.outputs.hash }}

      - name: Build and push base image
        uses: docker/build-push-action@v5
@@ -117,7 +148,7 @@ jobs:

      - name: Update Dockerfile.app to use registry base image
        run: |
          sed -i "s|FROM aether-base:latest|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest|g" Dockerfile.app
          sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app

      - name: Build and push app image
        uses: docker/build-push-action@v5
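The rebuild check above boils down to: hash the dependency-defining files, compare against a label stored on the previously published base image, and rebuild only on mismatch. A rough TypeScript (Node) sketch of that idea, not part of the repo; the `remoteBaseHash` helper and the local image tag are illustrative assumptions:

```typescript
import { createHash } from "node:crypto";
import { readFileSync } from "node:fs";
import { execSync } from "node:child_process";

// Same file set the workflow hashes via BASE_FILES.
const BASE_FILES = ["Dockerfile.base", "pyproject.toml", "frontend/package.json", "frontend/package-lock.json"];

// Concatenate the files and take a sha256, mirroring `cat ... | sha256sum`.
function currentBaseHash(): string {
  const hash = createHash("sha256");
  for (const f of BASE_FILES) {
    try { hash.update(readFileSync(f)); } catch { /* missing file: skip, like 2>/dev/null */ }
  }
  return hash.digest("hex");
}

// Read the hash label from an already-published base image (illustrative helper).
function remoteBaseHash(image: string): string | null {
  try {
    const out = execSync(
      `docker inspect ${image} --format '{{ index .Config.Labels "org.opencontainers.image.base.hash" }}'`
    ).toString().trim();
    return out && out !== "<no value>" ? out : null;
  } catch { return null; }
}

const remote = remoteBaseHash("ghcr.io/fawney19/aether-base:latest");
const needRebuild = remote === null || remote !== currentBaseHash();
console.log(`base_changed=${needRebuild}`);
```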
132  Dockerfile.app
@@ -1,16 +1,134 @@
# App image: based on the base image, copies code only (builds in seconds)
# Runtime image: pulls build artifacts from base into a slim runtime
# Build command: docker build -f Dockerfile.app -t aether-app:latest .
FROM aether-base:latest
# Used by GitHub Actions CI (official package sources)
FROM aether-base:latest AS builder

WORKDIR /app

# Copy frontend sources and build
COPY frontend/ ./frontend/
RUN cd frontend && npm run build

# ==================== Runtime image ====================
FROM python:3.12-slim

WORKDIR /app

# Runtime dependencies (no gcc/nodejs/npm)
RUN apt-get update && apt-get install -y \
    nginx \
    supervisor \
    libpq5 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy Python packages from the base image
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages

# Copy only the required Python executables
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/

# Copy the frontend build output from the builder stage
COPY --from=builder /app/frontend/dist /usr/share/nginx/html

# Copy backend code
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/

# Build the frontend (using node_modules already installed in the base image)
COPY frontend/ /tmp/frontend/
RUN cd /tmp/frontend && npm run build && \
    cp -r dist/* /usr/share/nginx/html/ && \
    rm -rf /tmp/frontend
# Nginx config template
RUN printf '%s\n' \
    'server {' \
    '    listen 80;' \
    '    server_name _;' \
    '    root /usr/share/nginx/html;' \
    '    index index.html;' \
    '    client_max_body_size 100M;' \
    '' \
    '    location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
    '        expires 1y;' \
    '        add_header Cache-Control "public, no-transform";' \
    '        try_files $uri =404;' \
    '    }' \
    '' \
    '    location ~ ^/(src|node_modules)/ {' \
    '        deny all;' \
    '        return 404;' \
    '    }' \
    '' \
    '    location ~ ^/(dashboard|admin|login)(/|$) {' \
    '        try_files $uri $uri/ /index.html;' \
    '    }' \
    '' \
    '    location / {' \
    '        try_files $uri $uri/ @backend;' \
    '    }' \
    '' \
    '    location @backend {' \
    '        proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
    '        proxy_http_version 1.1;' \
    '        proxy_set_header Host $host;' \
    '        proxy_set_header X-Real-IP $remote_addr;' \
    '        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
    '        proxy_set_header X-Forwarded-Proto $scheme;' \
    '        proxy_set_header Connection "";' \
    '        proxy_set_header Accept $http_accept;' \
    '        proxy_set_header Content-Type $content_type;' \
    '        proxy_set_header Authorization $http_authorization;' \
    '        proxy_set_header X-Api-Key $http_x_api_key;' \
    '        proxy_buffering off;' \
    '        proxy_cache off;' \
    '        proxy_request_buffering off;' \
    '        chunked_transfer_encoding on;' \
    '        gzip off;' \
    '        add_header X-Accel-Buffering no;' \
    '        proxy_connect_timeout 600s;' \
    '        proxy_send_timeout 600s;' \
    '        proxy_read_timeout 600s;' \
    '    }' \
    '}' > /etc/nginx/sites-available/default.template

# Supervisor config
RUN printf '%s\n' \
    '[supervisord]' \
    'nodaemon=true' \
    'logfile=/var/log/supervisor/supervisord.log' \
    'pidfile=/var/run/supervisord.pid' \
    '' \
    '[program:nginx]' \
    'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/var/log/nginx/access.log' \
    'stderr_logfile=/var/log/nginx/error.log' \
    '' \
    '[program:app]' \
    'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
    'directory=/app' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/dev/stdout' \
    'stdout_logfile_maxbytes=0' \
    'stderr_logfile=/dev/stderr' \
    'stderr_logfile_maxbytes=0' \
    'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf

# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data

# Environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONIOENCODING=utf-8 \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PORT=8084

EXPOSE 80

HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost/health || exit 1

CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
135  Dockerfile.app.local  (Normal file)
@@ -0,0 +1,135 @@
# Runtime image: pulls build artifacts from base into a slim runtime (China mirror version)
# Build command: docker build -f Dockerfile.app.local -t aether-app:latest .
# For local / China-hosted server deployments
FROM aether-base:latest AS builder

WORKDIR /app

# Copy frontend sources and build
COPY frontend/ ./frontend/
RUN cd frontend && npm run build

# ==================== Runtime image ====================
FROM python:3.12-slim

WORKDIR /app

# Runtime dependencies (using the Tsinghua mirror)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
    apt-get update && apt-get install -y \
    nginx \
    supervisor \
    libpq5 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy Python packages from the base image
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages

# Copy only the required Python executables
COPY --from=builder /usr/local/bin/gunicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/
COPY --from=builder /usr/local/bin/alembic /usr/local/bin/

# Copy the frontend build output from the builder stage
COPY --from=builder /app/frontend/dist /usr/share/nginx/html

# Copy backend code
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/

# Nginx config template
RUN printf '%s\n' \
    'server {' \
    '    listen 80;' \
    '    server_name _;' \
    '    root /usr/share/nginx/html;' \
    '    index index.html;' \
    '    client_max_body_size 100M;' \
    '' \
    '    location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
    '        expires 1y;' \
    '        add_header Cache-Control "public, no-transform";' \
    '        try_files $uri =404;' \
    '    }' \
    '' \
    '    location ~ ^/(src|node_modules)/ {' \
    '        deny all;' \
    '        return 404;' \
    '    }' \
    '' \
    '    location ~ ^/(dashboard|admin|login)(/|$) {' \
    '        try_files $uri $uri/ /index.html;' \
    '    }' \
    '' \
    '    location / {' \
    '        try_files $uri $uri/ @backend;' \
    '    }' \
    '' \
    '    location @backend {' \
    '        proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
    '        proxy_http_version 1.1;' \
    '        proxy_set_header Host $host;' \
    '        proxy_set_header X-Real-IP $remote_addr;' \
    '        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
    '        proxy_set_header X-Forwarded-Proto $scheme;' \
    '        proxy_set_header Connection "";' \
    '        proxy_set_header Accept $http_accept;' \
    '        proxy_set_header Content-Type $content_type;' \
    '        proxy_set_header Authorization $http_authorization;' \
    '        proxy_set_header X-Api-Key $http_x_api_key;' \
    '        proxy_buffering off;' \
    '        proxy_cache off;' \
    '        proxy_request_buffering off;' \
    '        chunked_transfer_encoding on;' \
    '        gzip off;' \
    '        add_header X-Accel-Buffering no;' \
    '        proxy_connect_timeout 600s;' \
    '        proxy_send_timeout 600s;' \
    '        proxy_read_timeout 600s;' \
    '    }' \
    '}' > /etc/nginx/sites-available/default.template

# Supervisor config
RUN printf '%s\n' \
    '[supervisord]' \
    'nodaemon=true' \
    'logfile=/var/log/supervisor/supervisord.log' \
    'pidfile=/var/run/supervisord.pid' \
    '' \
    '[program:nginx]' \
    'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/var/log/nginx/access.log' \
    'stderr_logfile=/var/log/nginx/error.log' \
    '' \
    '[program:app]' \
    'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
    'directory=/app' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/dev/stdout' \
    'stdout_logfile_maxbytes=0' \
    'stderr_logfile=/dev/stderr' \
    'stderr_logfile_maxbytes=0' \
    'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf

# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data

# Environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONIOENCODING=utf-8 \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PORT=8084

EXPOSE 80

HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost/health || exit 1

CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
117  Dockerfile.base
@@ -1,122 +1,25 @@
# Base image: contains all dependencies, only needs rebuilding when dependencies change
# Build image: compilation environment + pre-built dependencies
# For GitHub Actions CI builds (no China mirrors)
# Build command: docker build -f Dockerfile.base -t aether-base:latest .
# Only needs rebuilding when pyproject.toml or frontend/package*.json change
FROM python:3.12-slim

WORKDIR /app

# System dependencies
# Build tools
RUN apt-get update && apt-get install -y \
    nginx \
    supervisor \
    libpq-dev \
    gcc \
    curl \
    gettext-base \
    nodejs \
    npm \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies (installed system-wide, not in -e mode)
# Python dependencies
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
    pip install --no-cache-dir .
    SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
    pip cache purge

# Frontend dependencies
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm ci

# Nginx config template
RUN printf '%s\n' \
    'server {' \
    '    listen 80;' \
    '    server_name _;' \
    '    root /usr/share/nginx/html;' \
    '    index index.html;' \
    '    client_max_body_size 100M;' \
    '' \
    '    location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
    '        expires 1y;' \
    '        add_header Cache-Control "public, no-transform";' \
    '        try_files $uri =404;' \
    '    }' \
    '' \
    '    location ~ ^/(src|node_modules)/ {' \
    '        deny all;' \
    '        return 404;' \
    '    }' \
    '' \
    '    location ~ ^/(dashboard|admin|login)(/|$) {' \
    '        try_files $uri $uri/ /index.html;' \
    '    }' \
    '' \
    '    location / {' \
    '        try_files $uri $uri/ @backend;' \
    '    }' \
    '' \
    '    location @backend {' \
    '        proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
    '        proxy_http_version 1.1;' \
    '        proxy_set_header Host $host;' \
    '        proxy_set_header X-Real-IP $remote_addr;' \
    '        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
    '        proxy_set_header X-Forwarded-Proto $scheme;' \
    '        proxy_set_header Connection "";' \
    '        proxy_set_header Accept $http_accept;' \
    '        proxy_set_header Content-Type $content_type;' \
    '        proxy_set_header Authorization $http_authorization;' \
    '        proxy_set_header X-Api-Key $http_x_api_key;' \
    '        proxy_buffering off;' \
    '        proxy_cache off;' \
    '        proxy_request_buffering off;' \
    '        chunked_transfer_encoding on;' \
    '        proxy_connect_timeout 600s;' \
    '        proxy_send_timeout 600s;' \
    '        proxy_read_timeout 600s;' \
    '    }' \
    '}' > /etc/nginx/sites-available/default.template

# Supervisor config
RUN printf '%s\n' \
    '[supervisord]' \
    'nodaemon=true' \
    'logfile=/var/log/supervisor/supervisord.log' \
    'pidfile=/var/run/supervisord.pid' \
    '' \
    '[program:nginx]' \
    'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/var/log/nginx/access.log' \
    'stderr_logfile=/var/log/nginx/error.log' \
    '' \
    '[program:app]' \
    'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
    'directory=/app' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/dev/stdout' \
    'stdout_logfile_maxbytes=0' \
    'stderr_logfile=/dev/stderr' \
    'stderr_logfile_maxbytes=0' \
    'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf

# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html

WORKDIR /app

# Environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONIOENCODING=utf-8 \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PORT=8084

EXPOSE 80

HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost/health || exit 1

CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# Frontend dependencies (install only, no build)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm ci
@@ -1,18 +1,15 @@
# Base image: contains all dependencies, only needs rebuilding when dependencies change
# Build command: docker build -f Dockerfile.base -t aether-base:latest .
# Build image: compilation environment + pre-built dependencies (China mirror version)
# Build command: docker build -f Dockerfile.base.local -t aether-base:latest .
# Only needs rebuilding when pyproject.toml or frontend/package*.json change
FROM python:3.12-slim

WORKDIR /app

# System dependencies
# Build tools (using the Tsinghua mirror)
RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list.d/debian.sources && \
    apt-get update && apt-get install -y \
    nginx \
    supervisor \
    libpq-dev \
    gcc \
    curl \
    gettext-base \
    nodejs \
    npm \
    && rm -rf /var/lib/apt/lists/*
@@ -20,107 +17,12 @@ RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.li
# pip mirror
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

# Python dependencies (installed system-wide, not in -e mode)
# Python dependencies
COPY pyproject.toml README.md ./
RUN mkdir -p src && touch src/__init__.py && \
    SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir .
    SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir . && \
    pip cache purge

# Frontend dependencies
COPY frontend/package*.json /tmp/frontend/
WORKDIR /tmp/frontend
RUN npm config set registry https://registry.npmmirror.com && npm ci

# Nginx config template
RUN printf '%s\n' \
    'server {' \
    '    listen 80;' \
    '    server_name _;' \
    '    root /usr/share/nginx/html;' \
    '    index index.html;' \
    '    client_max_body_size 100M;' \
    '' \
    '    location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
    '        expires 1y;' \
    '        add_header Cache-Control "public, no-transform";' \
    '        try_files $uri =404;' \
    '    }' \
    '' \
    '    location ~ ^/(src|node_modules)/ {' \
    '        deny all;' \
    '        return 404;' \
    '    }' \
    '' \
    '    location ~ ^/(dashboard|admin|login)(/|$) {' \
    '        try_files $uri $uri/ /index.html;' \
    '    }' \
    '' \
    '    location / {' \
    '        try_files $uri $uri/ @backend;' \
    '    }' \
    '' \
    '    location @backend {' \
    '        proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
    '        proxy_http_version 1.1;' \
    '        proxy_set_header Host $host;' \
    '        proxy_set_header X-Real-IP $remote_addr;' \
    '        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
    '        proxy_set_header X-Forwarded-Proto $scheme;' \
    '        proxy_set_header Connection "";' \
    '        proxy_set_header Accept $http_accept;' \
    '        proxy_set_header Content-Type $content_type;' \
    '        proxy_set_header Authorization $http_authorization;' \
    '        proxy_set_header X-Api-Key $http_x_api_key;' \
    '        proxy_buffering off;' \
    '        proxy_cache off;' \
    '        proxy_request_buffering off;' \
    '        chunked_transfer_encoding on;' \
    '        proxy_connect_timeout 600s;' \
    '        proxy_send_timeout 600s;' \
    '        proxy_read_timeout 600s;' \
    '    }' \
    '}' > /etc/nginx/sites-available/default.template

# Supervisor config
RUN printf '%s\n' \
    '[supervisord]' \
    'nodaemon=true' \
    'logfile=/var/log/supervisor/supervisord.log' \
    'pidfile=/var/run/supervisord.pid' \
    '' \
    '[program:nginx]' \
    'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/var/log/nginx/access.log' \
    'stderr_logfile=/var/log/nginx/error.log' \
    '' \
    '[program:app]' \
    'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
    'directory=/app' \
    'autostart=true' \
    'autorestart=true' \
    'stdout_logfile=/dev/stdout' \
    'stdout_logfile_maxbytes=0' \
    'stderr_logfile=/dev/stderr' \
    'stderr_logfile_maxbytes=0' \
    'environment=PYTHONUNBUFFERED=1,PYTHONIOENCODING=utf-8,LANG=C.UTF-8,LC_ALL=C.UTF-8,DOCKER_CONTAINER=true' > /etc/supervisor/conf.d/supervisord.conf

# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data /usr/share/nginx/html

WORKDIR /app

# Environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONIOENCODING=utf-8 \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PORT=8084

EXPOSE 80

HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost/health || exit 1

CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# Frontend dependencies (install only, no build, using the npmmirror registry)
COPY frontend/package*.json ./frontend/
RUN cd frontend && npm config set registry https://registry.npmmirror.com && npm ci
@@ -394,6 +394,10 @@ def upgrade() -> None:
            index=True,
        ),
    )
    # Composite indexes on the usage table (optimize common queries)
    op.create_index("idx_usage_user_created", "usage", ["user_id", "created_at"])
    op.create_index("idx_usage_apikey_created", "usage", ["api_key_id", "created_at"])
    op.create_index("idx_usage_provider_model_created", "usage", ["provider", "model", "created_at"])

    # ==================== user_quotas ====================
    op.create_table(
@@ -0,0 +1,86 @@
"""add stats_daily_model table and rename provider_model_aliases

Revision ID: a1b2c3d4e5f6
Revises: f30f9936f6a2
Create Date: 2025-12-20 12:00:00.000000+00:00

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect

# revision identifiers, used by Alembic.
revision = 'a1b2c3d4e5f6'
down_revision = 'f30f9936f6a2'
branch_labels = None
depends_on = None


def table_exists(table_name: str) -> bool:
    """Check whether a table exists"""
    bind = op.get_bind()
    inspector = inspect(bind)
    return table_name in inspector.get_table_names()


def column_exists(table_name: str, column_name: str) -> bool:
    """Check whether a column exists"""
    bind = op.get_bind()
    inspector = inspect(bind)
    columns = [col['name'] for col in inspector.get_columns(table_name)]
    return column_name in columns


def upgrade() -> None:
    """Create the stats_daily_model table and rename provider_model_aliases to provider_model_mappings"""
    # 1. Create the stats_daily_model table
    if not table_exists('stats_daily_model'):
        op.create_table(
            'stats_daily_model',
            sa.Column('id', sa.String(36), primary_key=True),
            sa.Column('date', sa.DateTime(timezone=True), nullable=False),
            sa.Column('model', sa.String(100), nullable=False),
            sa.Column('total_requests', sa.Integer(), nullable=False, default=0),
            sa.Column('input_tokens', sa.BigInteger(), nullable=False, default=0),
            sa.Column('output_tokens', sa.BigInteger(), nullable=False, default=0),
            sa.Column('cache_creation_tokens', sa.BigInteger(), nullable=False, default=0),
            sa.Column('cache_read_tokens', sa.BigInteger(), nullable=False, default=0),
            sa.Column('total_cost', sa.Float(), nullable=False, default=0.0),
            sa.Column('avg_response_time_ms', sa.Float(), nullable=False, default=0.0),
            sa.Column('created_at', sa.DateTime(timezone=True), nullable=False,
                      server_default=sa.func.now()),
            sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False,
                      server_default=sa.func.now(), onupdate=sa.func.now()),
            sa.UniqueConstraint('date', 'model', name='uq_stats_daily_model'),
        )

        # Create indexes
        op.create_index('idx_stats_daily_model_date', 'stats_daily_model', ['date'])
        op.create_index('idx_stats_daily_model_date_model', 'stats_daily_model', ['date', 'model'])

    # 2. Rename models.provider_model_aliases to provider_model_mappings
    if column_exists('models', 'provider_model_aliases') and not column_exists('models', 'provider_model_mappings'):
        op.alter_column('models', 'provider_model_aliases', new_column_name='provider_model_mappings')


def index_exists(table_name: str, index_name: str) -> bool:
    """Check whether an index exists"""
    bind = op.get_bind()
    inspector = inspect(bind)
    indexes = [idx['name'] for idx in inspector.get_indexes(table_name)]
    return index_name in indexes


def downgrade() -> None:
    """Drop the stats_daily_model table and restore the provider_model_aliases column name"""
    # Restore the column name
    if column_exists('models', 'provider_model_mappings') and not column_exists('models', 'provider_model_aliases'):
        op.alter_column('models', 'provider_model_mappings', new_column_name='provider_model_aliases')

    # Drop the table
    if table_exists('stats_daily_model'):
        if index_exists('stats_daily_model', 'idx_stats_daily_model_date_model'):
            op.drop_index('idx_stats_daily_model_date_model', table_name='stats_daily_model')
        if index_exists('stats_daily_model', 'idx_stats_daily_model_date'):
            op.drop_index('idx_stats_daily_model_date', table_name='stats_daily_model')
        op.drop_table('stats_daily_model')
@@ -0,0 +1,79 @@
"""add usage table composite indexes for query optimization

Revision ID: b2c3d4e5f6g7
Revises: a1b2c3d4e5f6
Create Date: 2025-12-20 15:00:00.000000+00:00

"""
from alembic import op
from sqlalchemy import text

# revision identifiers, used by Alembic.
revision = 'b2c3d4e5f6g7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add composite indexes to the usage table to optimize common queries

    Indexes are created CONCURRENTLY to avoid locking the table,
    but this must run in AUTOCOMMIT mode (not inside a transaction).

    Note: when running against a brand-new database (the baseline just created the tables),
    the AUTOCOMMIT connection cannot see tables that are still uncommitted in the transaction,
    so index creation is skipped. In that case the indexes are created by a later migration
    or manually.
    """
    conn = op.get_bind()
    engine = conn.engine

    # Use a new connection in AUTOCOMMIT mode to support CREATE INDEX CONCURRENTLY
    with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as autocommit_conn:
        # Check whether the usage table exists (visible from the AUTOCOMMIT connection)
        # If the table does not exist (e.g. the baseline migration is still in a transaction), skip index creation
        result = autocommit_conn.execute(text(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'usage')"
        ))
        table_exists = result.scalar()

        if not table_exists:
            # Table not visible from this connection (the baseline may still be in a transaction); skip
            # The indexes will be created by a later migration or manually
            return

        # IF NOT EXISTS avoids duplicate creation, so no separate existence check is needed

        # 1. user_id + created_at composite index (per-user usage queries)
        autocommit_conn.execute(text(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_user_created "
            "ON usage (user_id, created_at)"
        ))

        # 2. api_key_id + created_at composite index (API key usage queries)
        autocommit_conn.execute(text(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_apikey_created "
            "ON usage (api_key_id, created_at)"
        ))

        # 3. provider + model + created_at composite index (model statistics queries)
        autocommit_conn.execute(text(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_provider_model_created "
            "ON usage (provider, model, created_at)"
        ))


def downgrade() -> None:
    """Drop the composite indexes"""
    conn = op.get_bind()

    # IF EXISTS avoids errors when an index does not exist
    conn.execute(text(
        "DROP INDEX IF EXISTS idx_usage_provider_model_created"
    ))
    conn.execute(text(
        "DROP INDEX IF EXISTS idx_usage_apikey_created"
    ))
    conn.execute(text(
        "DROP INDEX IF EXISTS idx_usage_user_created"
    ))
40  deploy.sh
@@ -21,15 +21,18 @@ HASH_FILE=".deps-hash"
CODE_HASH_FILE=".code-hash"
MIGRATION_HASH_FILE=".migration-hash"

# Compute the hash of dependency files
# Compute the hash of dependency files (including Dockerfile.base.local)
calc_deps_hash() {
    cat pyproject.toml frontend/package.json frontend/package-lock.json 2>/dev/null | md5sum | cut -d' ' -f1
    cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
}

# Compute the hash of code files
# Compute the hash of code files (including Dockerfile.app.local)
calc_code_hash() {
    find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1
    find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1
    {
        cat Dockerfile.app.local 2>/dev/null
        find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null
        find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null
    } | md5sum | cut -d' ' -f1
}

# Compute the hash of migration files
@@ -88,7 +91,7 @@ build_base() {
# Build the app image
build_app() {
    echo ">>> Building app image (code only)..."
    docker build -f Dockerfile.app -t aether-app:latest .
    docker build -f Dockerfile.app.local -t aether-app:latest .
    save_code_hash
}

@@ -162,29 +165,46 @@ git pull

# Flag whether a restart is needed
NEED_RESTART=false
BASE_REBUILT=false

# Check whether the base image exists or dependencies changed
if ! docker image inspect aether-base:latest >/dev/null 2>&1; then
    echo ">>> Base image not found, building..."
    build_base
    BASE_REBUILT=true
    NEED_RESTART=true
elif check_deps_changed; then
    echo ">>> Dependencies changed, rebuilding base image..."
    build_base
    BASE_REBUILT=true
    NEED_RESTART=true
else
    echo ">>> Dependencies unchanged."
fi

# Check whether the code changed
# Check whether code or migrations changed, or the base image was rebuilt (app depends on base)
# Note: migration files are packaged into the image, so migration changes also require rebuilding the app image
MIGRATION_CHANGED=false
if check_migration_changed; then
    MIGRATION_CHANGED=true
fi

if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
    echo ">>> App image not found, building..."
    build_app
    NEED_RESTART=true
elif [ "$BASE_REBUILT" = true ]; then
    echo ">>> Base image rebuilt, rebuilding app image..."
    build_app
    NEED_RESTART=true
elif check_code_changed; then
    echo ">>> Code changed, rebuilding app image..."
    build_app
    NEED_RESTART=true
elif [ "$MIGRATION_CHANGED" = true ]; then
    echo ">>> Migration files changed, rebuilding app image..."
    build_app
    NEED_RESTART=true
else
    echo ">>> Code unchanged."
fi
@@ -197,9 +217,9 @@ else
    echo ">>> No changes detected, skipping restart."
fi

# Check for migration changes
if check_migration_changed; then
    echo ">>> Migration files changed, running database migration..."
# Check for migration changes (if a change was already detected and the image rebuilt above, just run the migration here)
if [ "$MIGRATION_CHANGED" = true ]; then
    echo ">>> Running database migration..."
    sleep 3
    run_migration
else
3  dev.sh
@@ -8,7 +8,8 @@ source .env
set +a

# Build DATABASE_URL
export DATABASE_URL="postgresql://postgres:${DB_PASSWORD}@localhost:5432/aether"
export DATABASE_URL="postgresql://${DB_USER:-postgres}:${DB_PASSWORD}@${DB_HOST:-localhost}:${DB_PORT:-5432}/${DB_NAME:-aether}"
export REDIS_URL=redis://:${REDIS_PASSWORD}@${REDIS_HOST:-localhost}:${REDIS_PORT:-6379}/0

# Start uvicorn (hot-reload mode)
echo "🚀 启动本地开发服务器..."
@@ -41,7 +41,7 @@ services:
  app:
    build:
      context: .
      dockerfile: Dockerfile.app
      dockerfile: Dockerfile.app.local
    image: aether-app:latest
    container_name: aether-app
    environment:
@@ -112,7 +112,7 @@ export interface KeyExport {
export interface ModelExport {
  global_model_name: string | null
  provider_model_name: string
  provider_model_aliases?: any
  provider_model_mappings?: any
  price_per_request?: number | null
  tiered_pricing?: any
  supports_vision?: boolean | null
@@ -66,6 +66,7 @@ export interface UserAffinity {
  key_name: string | null
  key_prefix: string | null // Masked provider key (first 4 ... last 4)
  rate_multiplier: number
  global_model_id: string | null // Original global_model_id (used for deletion)
  model_name: string | null // Model name (e.g. claude-haiku-4-5-20250514)
  model_display_name: string | null // Model display name (e.g. Claude Haiku 4.5)
  api_format: string | null // API format (claude/openai)
@@ -119,6 +120,18 @@ export const cacheApi = {
    await api.delete(`/api/admin/monitoring/cache/users/${userIdentifier}`)
  },

  /**
   * Clear a single cached affinity entry
   *
   * @param affinityKey API Key ID
   * @param endpointId Endpoint ID
   * @param modelId GlobalModel ID
   * @param apiFormat API format (claude/openai)
   */
  async clearSingleAffinity(affinityKey: string, endpointId: string, modelId: string, apiFormat: string): Promise<void> {
    await api.delete(`/api/admin/monitoring/cache/affinity/${affinityKey}/${endpointId}/${modelId}/${apiFormat}`)
  },

  /**
   * Clear all caches
   */
|
||||
ModelUpdate,
|
||||
ModelCatalogResponse,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
UpstreamModel,
|
||||
ImportFromUpstreamResponse,
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
@@ -119,3 +121,40 @@ export async function batchAssignModelsToProvider(
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询提供商的上游模型列表
|
||||
*/
|
||||
export async function queryProviderUpstreamModels(
|
||||
providerId: string
|
||||
): Promise<{
|
||||
success: boolean
|
||||
data: {
|
||||
models: UpstreamModel[]
|
||||
error: string | null
|
||||
}
|
||||
provider: {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
}
|
||||
}> {
|
||||
const response = await client.post('/api/admin/provider-query/models', {
|
||||
provider_id: providerId,
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 从上游提供商导入模型
|
||||
*/
|
||||
export async function importModelsFromUpstream(
|
||||
providerId: string,
|
||||
modelIds: string[]
|
||||
): Promise<ImportFromUpstreamResponse> {
|
||||
const response = await client.post(
|
||||
`/api/admin/providers/${providerId}/import-from-upstream`,
|
||||
{ model_ids: modelIds }
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -110,6 +110,24 @@ export interface EndpointAPIKey {
|
||||
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
|
||||
}
|
||||
|
||||
export interface EndpointAPIKeyUpdate {
|
||||
name?: string
|
||||
api_key?: string // 仅在需要更新时提供
|
||||
rate_multiplier?: number
|
||||
internal_priority?: number
|
||||
global_priority?: number | null
|
||||
max_concurrent?: number | null // null 表示切换为自适应模式
|
||||
rate_limit?: number
|
||||
daily_limit?: number
|
||||
monthly_limit?: number
|
||||
allowed_models?: string[] | null
|
||||
capabilities?: Record<string, boolean> | null
|
||||
cache_ttl_minutes?: number
|
||||
max_probe_interval_minutes?: number
|
||||
note?: string
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface EndpointHealthDetail {
|
||||
api_format: string
|
||||
health_score: number
|
||||
@@ -244,18 +262,21 @@ export interface ConcurrencyStatus {
|
||||
key_max_concurrent?: number
|
||||
}
|
||||
|
||||
export interface ProviderModelAlias {
|
||||
export interface ProviderModelMapping {
|
||||
name: string
|
||||
priority: number // 优先级(数字越小优先级越高)
|
||||
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
|
||||
}
|
||||
|
||||
// 保留别名以保持向后兼容
|
||||
export type ProviderModelAlias = ProviderModelMapping
|
||||
|
||||
export interface Model {
|
||||
id: string
|
||||
provider_id: string
|
||||
global_model_id?: string // 关联的 GlobalModel ID
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
provider_model_mappings?: ProviderModelMapping[] | null // 模型名称映射列表(带优先级)
|
||||
// 原始配置值(可能为空,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number | null // 按次计费价格
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -285,7 +306,7 @@ export interface Model {
|
||||
|
||||
export interface ModelCreate {
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] // 模型名称别名列表(带优先级)
|
||||
provider_model_mappings?: ProviderModelMapping[] // 模型名称映射列表(带优先级)
|
||||
global_model_id: string // 关联的 GlobalModel ID(必填)
|
||||
// 计费配置(可选,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number // 按次计费价格
|
||||
@@ -302,7 +323,7 @@ export interface ModelCreate {
|
||||
|
||||
export interface ModelUpdate {
|
||||
provider_model_name?: string
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
provider_model_mappings?: ProviderModelMapping[] | null // 模型名称映射列表(带优先级)
|
||||
global_model_id?: string
|
||||
price_per_request?: number | null // 按次计费价格(null 表示清空/使用默认值)
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -495,3 +516,42 @@ export interface GlobalModelListResponse {
|
||||
models: GlobalModelResponse[]
|
||||
total: number
|
||||
}
|
||||
|
||||
// ==================== 上游模型导入相关 ====================
|
||||
|
||||
/**
|
||||
* 上游模型(从提供商 API 获取的原始模型)
|
||||
*/
|
||||
export interface UpstreamModel {
|
||||
id: string
|
||||
owned_by?: string
|
||||
display_name?: string
|
||||
api_format?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 导入成功的模型信息
|
||||
*/
|
||||
export interface ImportFromUpstreamSuccessItem {
|
||||
model_id: string
|
||||
global_model_id: string
|
||||
global_model_name: string
|
||||
provider_model_id: string
|
||||
created_global_model: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 导入失败的模型信息
|
||||
*/
|
||||
export interface ImportFromUpstreamErrorItem {
|
||||
model_id: string
|
||||
error: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 从上游提供商导入模型响应
|
||||
*/
|
||||
export interface ImportFromUpstreamResponse {
|
||||
success: ImportFromUpstreamSuccessItem[]
|
||||
errors: ImportFromUpstreamErrorItem[]
|
||||
}
|
||||
|
||||
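A quick sketch of what the rename means for callers: new code uses provider_model_mappings / ProviderModelMapping, while the retained alias keeps older annotations compiling. The model names below are illustrative:

```typescript
// New-style mapping list for a ModelCreate payload (values are illustrative).
const mappings: ProviderModelMapping[] = [
  { name: 'claude-3-5-haiku', priority: 1, api_formats: ['claude'] },
  { name: 'haiku-latest', priority: 2 }, // no api_formats: applies to all formats
]

// The old annotation still type-checks thanks to the backward-compat alias.
const legacy: ProviderModelAlias[] = mappings
```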
@@ -92,6 +92,7 @@

<script setup lang="ts">
import { computed, useSlots, type Component } from 'vue'
import { useEscapeKey } from '@/composables/useEscapeKey'

// Props definition
const props = defineProps<{
@@ -157,4 +158,14 @@ const maxWidthClass = computed(() => {
const containerZIndex = computed(() => props.zIndex || 60)
const backdropZIndex = computed(() => props.zIndex || 60)
const contentZIndex = computed(() => (props.zIndex || 60) + 10)

// Register the ESC key listener
useEscapeKey(() => {
  if (isOpen.value) {
    handleClose()
  }
}, {
  disableOnInput: true,
  once: false
})
</script>
80  frontend/src/composables/useEscapeKey.ts  (Normal file)
@@ -0,0 +1,80 @@
import { onMounted, onUnmounted, ref } from 'vue'

/**
 * ESC key listener composable (simplified version using a standalone listener)
 * Used to close dialogs or other dismissible components with the ESC key
 *
 * @param callback - callback executed when ESC is pressed
 * @param options - configuration options
 */
export function useEscapeKey(
  callback: () => void,
  options: {
    /** Whether to disable ESC while an input has focus, default true */
    disableOnInput?: boolean
    /** Whether to listen only once, default false */
    once?: boolean
  } = {}
) {
  const { disableOnInput = true, once = false } = options
  const isActive = ref(true)

  function handleKeyDown(event: KeyboardEvent) {
    // Only handle the ESC key
    if (event.key !== 'Escape') return

    // Check that the component is still active
    if (!isActive.value) return

    // If disabled-on-input is configured, check the currently focused element
    if (disableOnInput) {
      const activeElement = document.activeElement
      const isInputElement = activeElement && (
        activeElement.tagName === 'INPUT' ||
        activeElement.tagName === 'TEXTAREA' ||
        activeElement.tagName === 'SELECT' ||
        activeElement.contentEditable === 'true' ||
        activeElement.getAttribute('role') === 'textbox' ||
        activeElement.getAttribute('role') === 'combobox'
      )

      // Ignore ESC while focus is inside an input
      if (isInputElement) return
    }

    // Run the callback
    callback()

    // Blur the current element to avoid leftover focus styles
    if (document.activeElement instanceof HTMLElement) {
      document.activeElement.blur()
    }

    // If listening only once, remove the listener
    if (once) {
      removeEventListener()
    }
  }

  function addEventListener() {
    document.addEventListener('keydown', handleKeyDown)
  }

  function removeEventListener() {
    document.removeEventListener('keydown', handleKeyDown)
  }

  onMounted(() => {
    addEventListener()
  })

  onUnmounted(() => {
    isActive.value = false
    removeEventListener()
  })

  return {
    addEventListener,
    removeEventListener
  }
}
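Both call sites below use the composable with its defaults; for completeness, a hedged sketch of the other options it exposes (the callback and component context are illustrative):

```typescript
// Inside any <script setup> component: fire only once, even while an input has focus.
const { removeEventListener } = useEscapeKey(
  () => console.log('escape pressed'),
  { disableOnInput: false, once: true }
)

// The returned handles allow detaching the listener early if needed.
removeEventListener()
```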
@@ -698,6 +698,7 @@ import {
  Layers,
  BarChart3
} from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey'
import { useToast } from '@/composables/useToast'
import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue'
@@ -833,6 +834,16 @@ watch(() => props.open, (newOpen) => {
    detailTab.value = 'basic'
  }
})

// Register the ESC key listener
useEscapeKey(() => {
  if (props.open) {
    handleClose()
  }
}, {
  disableOnInput: true,
  once: false
})
</script>

<style scoped>
@@ -31,29 +31,46 @@

<!-- Side-by-side comparison layout -->
<div class="flex gap-2 items-stretch">
<!-- Left: models available to add -->
<!-- Left: models available to add (collapsible groups) -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between">
<div class="flex items-center gap-2">
<p class="text-sm font-medium">
可添加
</p>
<Button
v-if="availableModels.length > 0"
variant="ghost"
size="sm"
class="h-6 px-2 text-xs"
@click="toggleSelectAllLeft"
>
{{ isAllLeftSelected ? '取消全选' : '全选' }}
</Button>
<div class="flex items-center justify-between gap-2">
<p class="text-sm font-medium shrink-0">
可添加
</p>
<div class="flex-1 relative">
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
<Input
v-model="searchQuery"
placeholder="搜索模型..."
class="pl-7 h-7 text-xs"
/>
</div>
<Badge
variant="secondary"
class="text-xs"
<button
v-if="upstreamModelsLoaded"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="刷新上游模型"
:disabled="fetchingUpstreamModels"
@click="fetchUpstreamModels(true)"
>
{{ availableModels.length }} 个
</Badge>
<RefreshCw
class="w-3.5 h-3.5"
:class="{ 'animate-spin': fetchingUpstreamModels }"
/>
</button>
<button
v-else-if="!fetchingUpstreamModels"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="从提供商获取模型"
@click="fetchUpstreamModels"
>
<Zap class="w-3.5 h-3.5" />
</button>
<Loader2
v-else
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
/>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
@@ -63,7 +80,7 @@
<Loader2 class="w-6 h-6 animate-spin text-primary" />
</div>
<div
v-else-if="availableModels.length === 0"
v-else-if="totalAvailableCount === 0 && !upstreamModelsLoaded"
class="flex flex-col items-center justify-center h-full text-muted-foreground"
>
<Layers class="w-10 h-10 mb-2 opacity-30" />
@@ -73,37 +90,142 @@
</div>
<div
v-else
class="p-2 space-y-1"
class="p-2 space-y-2"
>
<!-- Collapsible group: global models -->
<div
v-for="model in availableModels"
:key="model.id"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors"
:class="selectedLeftIds.includes(model.id)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50 cursor-pointer'"
@click="toggleLeftSelection(model.id)"
v-if="availableGlobalModels.length > 0 || !upstreamModelsLoaded"
class="border rounded-lg overflow-hidden"
>
<Checkbox
:checked="selectedLeftIds.includes(model.id)"
@update:checked="toggleLeftSelection(model.id)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ model.display_name }}
</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ model.name }}
</p>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse('global')"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has('global') ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">
全局模型
</span>
<span class="text-xs text-muted-foreground">
({{ availableGlobalModels.length }})
</span>
</button>
<button
v-if="availableGlobalModels.length > 0"
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllGlobalModels"
>
{{ isAllGlobalModelsSelected ? '取消' : '全选' }}
</button>
</div>
<Badge
:variant="model.is_active ? 'outline' : 'secondary'"
:class="model.is_active ? 'text-green-600 border-green-500/60' : ''"
class="text-xs shrink-0"
<div
v-show="!collapsedGroups.has('global')"
class="p-2 space-y-1 border-t"
>
{{ model.is_active ? '活跃' : '停用' }}
</Badge>
<div
v-if="availableGlobalModels.length === 0"
class="py-4 text-center text-xs text-muted-foreground"
>
所有全局模型均已关联
</div>
<div
v-for="model in availableGlobalModels"
v-else
:key="model.id"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedGlobalModelIds.includes(model.id)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleGlobalModelSelection(model.id)"
>
<Checkbox
:checked="selectedGlobalModelIds.includes(model.id)"
@update:checked="toggleGlobalModelSelection(model.id)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ model.display_name }}
</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ model.name }}
</p>
</div>
<Badge
:variant="model.is_active ? 'outline' : 'secondary'"
:class="model.is_active ? 'text-green-600 border-green-500/60' : ''"
class="text-xs shrink-0"
>
{{ model.is_active ? '活跃' : '停用' }}
</Badge>
</div>
</div>
</div>

<!-- Collapsible groups: models fetched from the provider -->
<div
v-for="group in upstreamModelGroups"
:key="group.api_format"
class="border rounded-lg overflow-hidden"
>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse(group.api_format)"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has(group.api_format) ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">
{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}
</span>
<span class="text-xs text-muted-foreground">
({{ group.models.length }})
</span>
</button>
<button
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllUpstreamModels(group.api_format)"
>
{{ isUpstreamGroupAllSelected(group.api_format) ? '取消' : '全选' }}
</button>
</div>
<div
v-show="!collapsedGroups.has(group.api_format)"
class="p-2 space-y-1 border-t"
>
<div
v-for="model in group.models"
:key="model.id"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedUpstreamModelIds.includes(model.id)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleUpstreamModelSelection(model.id)"
>
<Checkbox
:checked="selectedUpstreamModelIds.includes(model.id)"
@update:checked="toggleUpstreamModelSelection(model.id)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ model.id }}
</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ model.owned_by || model.id }}
</p>
</div>
</div>
</div>
</div>
</div>
</div>
@@ -115,8 +237,8 @@
variant="outline"
size="sm"
class="w-9 h-8"
:class="selectedLeftIds.length > 0 && !submittingAdd ? 'border-primary' : ''"
:disabled="selectedLeftIds.length === 0 || submittingAdd"
:class="totalSelectedCount > 0 && !submittingAdd ? 'border-primary' : ''"
:disabled="totalSelectedCount === 0 || submittingAdd"
title="添加选中"
@click="batchAddSelected"
>
@@ -127,7 +249,7 @@
<ChevronRight
v-else
class="w-6 h-6 stroke-[3]"
:class="selectedLeftIds.length > 0 && !submittingAdd ? 'text-primary' : ''"
:class="totalSelectedCount > 0 && !submittingAdd ? 'text-primary' : ''"
/>
</Button>
<Button
@@ -154,26 +276,18 @@
<!-- Right: already-added models -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between">
<div class="flex items-center gap-2">
<p class="text-sm font-medium">
已添加
</p>
<Button
v-if="existingModels.length > 0"
variant="ghost"
size="sm"
class="h-6 px-2 text-xs"
@click="toggleSelectAllRight"
>
{{ isAllRightSelected ? '取消全选' : '全选' }}
</Button>
</div>
<Badge
variant="secondary"
class="text-xs"
<p class="text-sm font-medium">
已添加
</p>
<Button
v-if="existingModels.length > 0"
variant="ghost"
size="sm"
class="h-6 px-2 text-xs"
@click="toggleSelectAllRight"
>
{{ existingModels.length }} 个
</Badge>
{{ isAllRightSelected ? '取消' : '全选' }}
</Button>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
@@ -238,11 +352,12 @@

<script setup lang="ts">
import { ref, computed, watch } from 'vue'
import { Layers, Loader2, ChevronRight, ChevronLeft } from 'lucide-vue-next'
import { Layers, Loader2, ChevronRight, ChevronLeft, ChevronDown, Zap, RefreshCw, Search } from 'lucide-vue-next'
import Dialog from '@/components/ui/dialog/Dialog.vue'
import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue'
import Checkbox from '@/components/ui/checkbox.vue'
import Input from '@/components/ui/input.vue'
import { useToast } from '@/composables/useToast'
import { parseApiError } from '@/utils/errorParser'
import {
@@ -253,8 +368,13 @@ import {
  getProviderModels,
  batchAssignModelsToProvider,
  deleteModel,
  importModelsFromUpstream,
  API_FORMAT_LABELS,
  type Model
} from '@/api/endpoints'
import { useUpstreamModelsCache, type UpstreamModel } from '../composables/useUpstreamModelsCache'

const { fetchModels: fetchCachedModels, clearCache, getCachedModels } = useUpstreamModelsCache()

const props = defineProps<{
  open: boolean
@@ -274,17 +394,27 @@ const { error: showError, success } = useToast()
const loadingGlobalModels = ref(false)
const submittingAdd = ref(false)
const submittingRemove = ref(false)
const fetchingUpstreamModels = ref(false)
const upstreamModelsLoaded = ref(false)

// Data
const allGlobalModels = ref<GlobalModelResponse[]>([])
const existingModels = ref<Model[]>([])
const upstreamModels = ref<UpstreamModel[]>([])

// Selection state
const selectedLeftIds = ref<string[]>([])
const selectedGlobalModelIds = ref<string[]>([])
const selectedUpstreamModelIds = ref<string[]>([])
const selectedRightIds = ref<string[]>([])

// Models available to add (excluding those already linked)
const availableModels = computed(() => {
// Collapse state
const collapsedGroups = ref<Set<string>>(new Set())

// Search state
const searchQuery = ref('')

// Global models available to add (excluding those already linked)
const availableGlobalModelsBase = computed(() => {
  const existingGlobalModelIds = new Set(
    existingModels.value
      .filter(m => m.global_model_id)
@@ -293,31 +423,123 @@ const availableModels = computed(() => {
  return allGlobalModels.value.filter(m => !existingGlobalModelIds.has(m.id))
})

// Select-all state
const isAllLeftSelected = computed(() =>
  availableModels.value.length > 0 &&
  selectedLeftIds.value.length === availableModels.value.length
)
// Global models after search filtering
const availableGlobalModels = computed(() => {
  if (!searchQuery.value.trim()) return availableGlobalModelsBase.value
  const query = searchQuery.value.toLowerCase()
  return availableGlobalModelsBase.value.filter(m =>
    m.name.toLowerCase().includes(query) ||
    m.display_name.toLowerCase().includes(query)
  )
})

// Upstream models available to add (excluding those already linked)
const availableUpstreamModelsBase = computed(() => {
  const existingModelNames = new Set(
    existingModels.value.map(m => m.provider_model_name)
  )
  return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
})

// Upstream models after search filtering
const availableUpstreamModels = computed(() => {
  if (!searchQuery.value.trim()) return availableUpstreamModelsBase.value
  const query = searchQuery.value.toLowerCase()
  return availableUpstreamModelsBase.value.filter(m =>
    m.id.toLowerCase().includes(query) ||
    (m.owned_by && m.owned_by.toLowerCase().includes(query))
  )
})

// Upstream models grouped by API format
const upstreamModelGroups = computed(() => {
  const groups: Record<string, UpstreamModel[]> = {}

  for (const model of availableUpstreamModels.value) {
    const format = model.api_format || 'unknown'
    if (!groups[format]) {
      groups[format] = []
    }
    groups[format].push(model)
  }

  // Sort in API_FORMAT_LABELS order
  const order = Object.keys(API_FORMAT_LABELS)
  return Object.entries(groups)
    .map(([api_format, models]) => ({ api_format, models }))
    .sort((a, b) => {
      const aIndex = order.indexOf(a.api_format)
      const bIndex = order.indexOf(b.api_format)
      if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
      if (aIndex === -1) return 1
|
||||
if (bIndex === -1) return -1
|
||||
return aIndex - bIndex
|
||||
})
|
||||
})
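
For reference, the comparator above orders groups by the key order of `API_FORMAT_LABELS` and pushes unrecognized formats to the end, alphabetically among themselves. A small sketch with made-up format keys (the real keys live in `@/api/endpoints`):

```ts
// Hypothetical keys for illustration; only the ordering rule matters here.
const order = ['OPENAI', 'ANTHROPIC', 'GEMINI'] // stands in for Object.keys(API_FORMAT_LABELS)
const formats = ['my_custom_format', 'GEMINI', 'OPENAI']
formats.sort((a, b) => {
  const aIndex = order.indexOf(a)
  const bIndex = order.indexOf(b)
  if (aIndex === -1 && bIndex === -1) return a.localeCompare(b) // both unknown: alphabetical
  if (aIndex === -1) return 1  // unknown formats sink to the end
  if (bIndex === -1) return -1
  return aIndex - bIndex       // known formats follow label order
})
console.log(formats) // ['OPENAI', 'GEMINI', 'my_custom_format']
```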
|
||||
|
||||
// 总可添加数量
|
||||
const totalAvailableCount = computed(() => {
|
||||
return availableGlobalModels.value.length + availableUpstreamModels.value.length
|
||||
})
|
||||
|
||||
// 总选中数量
|
||||
const totalSelectedCount = computed(() => {
|
||||
return selectedGlobalModelIds.value.length + selectedUpstreamModelIds.value.length
|
||||
})
|
||||
|
||||
// 全选状态
|
||||
const isAllRightSelected = computed(() =>
|
||||
existingModels.value.length > 0 &&
|
||||
selectedRightIds.value.length === existingModels.value.length
|
||||
)
|
||||
|
||||
// 全局模型是否全选
|
||||
const isAllGlobalModelsSelected = computed(() => {
|
||||
if (availableGlobalModels.value.length === 0) return false
|
||||
return availableGlobalModels.value.every(m => selectedGlobalModelIds.value.includes(m.id))
|
||||
})
|
||||
|
||||
// 检查某个上游组是否全选
|
||||
function isUpstreamGroupAllSelected(apiFormat: string): boolean {
|
||||
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
|
||||
if (!group || group.models.length === 0) return false
|
||||
return group.models.every(m => selectedUpstreamModelIds.value.includes(m.id))
|
||||
}
|
||||
|
||||
// 监听打开状态
|
||||
watch(() => props.open, async (isOpen) => {
|
||||
if (isOpen && props.providerId) {
|
||||
await loadData()
|
||||
} else {
|
||||
// 重置状态
|
||||
selectedLeftIds.value = []
|
||||
selectedGlobalModelIds.value = []
|
||||
selectedUpstreamModelIds.value = []
|
||||
selectedRightIds.value = []
|
||||
upstreamModels.value = []
|
||||
upstreamModelsLoaded.value = false
|
||||
collapsedGroups.value = new Set()
|
||||
searchQuery.value = ''
|
||||
}
|
||||
})
|
||||
|
||||
// 加载数据
|
||||
async function loadData() {
|
||||
await Promise.all([loadGlobalModels(), loadExistingModels()])
|
||||
// 默认折叠全局模型组
|
||||
collapsedGroups.value = new Set(['global'])
|
||||
|
||||
// 检查缓存,如果有缓存数据则直接使用
|
||||
const cachedModels = getCachedModels(props.providerId)
|
||||
if (cachedModels) {
|
||||
upstreamModels.value = cachedModels
|
||||
upstreamModelsLoaded.value = true
|
||||
// 折叠所有上游模型组
|
||||
for (const model of cachedModels) {
|
||||
if (model.api_format) {
|
||||
collapsedGroups.value.add(model.api_format)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
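
`useUpstreamModelsCache` itself is not part of this diff. Judging by how it is called here (`getCachedModels`, `fetchModels(providerId, forceRefresh)`, `clearCache`) and by the inline cache this refactor removes further down (a `Map` keyed by provider id with a 5-minute TTL), a minimal sketch might look like the following; the `loader` parameter is a stand-in of mine, not the repo's API.

```ts
// Hypothetical sketch only - the real composable is not shown in this diff.
import { ref } from 'vue'

export interface UpstreamModel { id: string; owned_by?: string; api_format?: string }

const CACHE_TTL = 5 * 60 * 1000 // 5 minutes, as in the inline cache being replaced
const cache = ref(new Map<string, { models: UpstreamModel[]; timestamp: number }>())

export function useUpstreamModelsCache(
  // stand-in for whatever API call the real composable performs
  loader: (providerId: string) => Promise<UpstreamModel[]>,
) {
  function getCachedModels(providerId: string): UpstreamModel[] | null {
    const entry = cache.value.get(providerId)
    if (!entry || Date.now() - entry.timestamp > CACHE_TTL) return null
    return entry.models
  }

  function clearCache(providerId: string) {
    cache.value.delete(providerId)
  }

  async function fetchModels(providerId: string, forceRefresh = false) {
    if (!forceRefresh) {
      const cached = getCachedModels(providerId)
      if (cached) return { models: cached, error: null }
    }
    try {
      const models = await loader(providerId)
      cache.value.set(providerId, { models, timestamp: Date.now() })
      return { models, error: null }
    } catch (err: any) {
      return { models: [] as UpstreamModel[], error: String(err?.message ?? err) }
    }
  }

  return { fetchModels, clearCache, getCachedModels }
}
```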
|
||||
|
||||
// 加载全局模型列表
|
||||
@@ -342,13 +564,91 @@ async function loadExistingModels() {
|
||||
}
|
||||
}
|
||||
|
||||
// 切换左侧选择
|
||||
function toggleLeftSelection(id: string) {
|
||||
const index = selectedLeftIds.value.indexOf(id)
|
||||
if (index === -1) {
|
||||
selectedLeftIds.value.push(id)
|
||||
// 从提供商获取模型
|
||||
async function fetchUpstreamModels(forceRefresh = false) {
|
||||
if (forceRefresh) {
|
||||
clearCache(props.providerId)
|
||||
}
|
||||
|
||||
try {
|
||||
fetchingUpstreamModels.value = true
|
||||
const result = await fetchCachedModels(props.providerId, forceRefresh)
|
||||
if (result) {
|
||||
if (result.error) {
|
||||
showError(result.error, '错误')
|
||||
} else {
|
||||
upstreamModels.value = result.models
|
||||
upstreamModelsLoaded.value = true
|
||||
// 折叠所有上游模型组
|
||||
const allGroups = new Set(collapsedGroups.value)
|
||||
for (const model of result.models) {
|
||||
if (model.api_format) {
|
||||
allGroups.add(model.api_format)
|
||||
}
|
||||
}
|
||||
collapsedGroups.value = allGroups
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
fetchingUpstreamModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 切换折叠状态
|
||||
function toggleGroupCollapse(group: string) {
|
||||
if (collapsedGroups.value.has(group)) {
|
||||
collapsedGroups.value.delete(group)
|
||||
} else {
|
||||
selectedLeftIds.value.splice(index, 1)
|
||||
collapsedGroups.value.add(group)
|
||||
}
|
||||
// 触发响应式更新
|
||||
collapsedGroups.value = new Set(collapsedGroups.value)
|
||||
}
|
||||
|
||||
// 切换全局模型选择
|
||||
function toggleGlobalModelSelection(id: string) {
|
||||
const index = selectedGlobalModelIds.value.indexOf(id)
|
||||
if (index === -1) {
|
||||
selectedGlobalModelIds.value.push(id)
|
||||
} else {
|
||||
selectedGlobalModelIds.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 切换上游模型选择
|
||||
function toggleUpstreamModelSelection(id: string) {
|
||||
const index = selectedUpstreamModelIds.value.indexOf(id)
|
||||
if (index === -1) {
|
||||
selectedUpstreamModelIds.value.push(id)
|
||||
} else {
|
||||
selectedUpstreamModelIds.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 全选全局模型
|
||||
function selectAllGlobalModels() {
|
||||
const allIds = availableGlobalModels.value.map(m => m.id)
|
||||
const allSelected = allIds.every(id => selectedGlobalModelIds.value.includes(id))
|
||||
if (allSelected) {
|
||||
selectedGlobalModelIds.value = selectedGlobalModelIds.value.filter(id => !allIds.includes(id))
|
||||
} else {
|
||||
const newIds = allIds.filter(id => !selectedGlobalModelIds.value.includes(id))
|
||||
selectedGlobalModelIds.value.push(...newIds)
|
||||
}
|
||||
}
|
||||
|
||||
// 全选某个 API 格式的上游模型
|
||||
function selectAllUpstreamModels(apiFormat: string) {
|
||||
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
|
||||
if (!group) return
|
||||
|
||||
const allIds = group.models.map(m => m.id)
|
||||
const allSelected = allIds.every(id => selectedUpstreamModelIds.value.includes(id))
|
||||
if (allSelected) {
|
||||
selectedUpstreamModelIds.value = selectedUpstreamModelIds.value.filter(id => !allIds.includes(id))
|
||||
} else {
|
||||
const newIds = allIds.filter(id => !selectedUpstreamModelIds.value.includes(id))
|
||||
selectedUpstreamModelIds.value.push(...newIds)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,15 +662,6 @@ function toggleRightSelection(id: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 全选/取消全选左侧
|
||||
function toggleSelectAllLeft() {
|
||||
if (isAllLeftSelected.value) {
|
||||
selectedLeftIds.value = []
|
||||
} else {
|
||||
selectedLeftIds.value = availableModels.value.map(m => m.id)
|
||||
}
|
||||
}
|
||||
|
||||
// 全选/取消全选右侧
|
||||
function toggleSelectAllRight() {
|
||||
if (isAllRightSelected.value) {
|
||||
@@ -382,22 +673,41 @@ function toggleSelectAllRight() {
|
||||
|
||||
// 批量添加选中的模型
|
||||
async function batchAddSelected() {
|
||||
if (selectedLeftIds.value.length === 0) return
|
||||
if (totalSelectedCount.value === 0) return
|
||||
|
||||
try {
|
||||
submittingAdd.value = true
|
||||
const result = await batchAssignModelsToProvider(props.providerId, selectedLeftIds.value)
|
||||
let totalSuccess = 0
|
||||
const allErrors: string[] = []
|
||||
|
||||
if (result.success.length > 0) {
|
||||
success(`成功添加 ${result.success.length} 个模型`)
|
||||
// 处理全局模型
|
||||
if (selectedGlobalModelIds.value.length > 0) {
|
||||
const result = await batchAssignModelsToProvider(props.providerId, selectedGlobalModelIds.value)
|
||||
totalSuccess += result.success.length
|
||||
if (result.errors.length > 0) {
|
||||
allErrors.push(...result.errors.map(e => e.error))
|
||||
}
|
||||
}
|
||||
|
||||
if (result.errors.length > 0) {
|
||||
const errorMessages = result.errors.map(e => e.error).join(', ')
|
||||
showError(`部分模型添加失败: ${errorMessages}`, '警告')
|
||||
// 处理上游模型(调用 import-from-upstream API)
|
||||
if (selectedUpstreamModelIds.value.length > 0) {
|
||||
const result = await importModelsFromUpstream(props.providerId, selectedUpstreamModelIds.value)
|
||||
totalSuccess += result.success.length
|
||||
if (result.errors.length > 0) {
|
||||
allErrors.push(...result.errors.map(e => e.error))
|
||||
}
|
||||
}
|
||||
|
||||
selectedLeftIds.value = []
|
||||
if (totalSuccess > 0) {
|
||||
success(`成功添加 ${totalSuccess} 个模型`)
|
||||
}
|
||||
|
||||
if (allErrors.length > 0) {
|
||||
showError(`部分模型添加失败: ${allErrors.slice(0, 3).join(', ')}${allErrors.length > 3 ? '...' : ''}`, '警告')
|
||||
}
|
||||
|
||||
selectedGlobalModelIds.value = []
|
||||
selectedUpstreamModelIds.value = []
|
||||
await loadExistingModels()
|
||||
emit('changed')
|
||||
} catch (err: any) {
|
||||
|
||||
@@ -260,6 +260,7 @@ import {
|
||||
updateEndpointKey,
|
||||
getAllCapabilities,
|
||||
type EndpointAPIKey,
|
||||
type EndpointAPIKeyUpdate,
|
||||
type ProviderEndpoint,
|
||||
type CapabilityDefinition
|
||||
} from '@/api/endpoints'
|
||||
@@ -386,10 +387,11 @@ function loadKeyData() {
api_key: '',
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
internal_priority: props.editingKey.internal_priority ?? 50,
max_concurrent: props.editingKey.max_concurrent || undefined,
rate_limit: props.editingKey.rate_limit || undefined,
daily_limit: props.editingKey.daily_limit || undefined,
monthly_limit: props.editingKey.monthly_limit || undefined,
// 保留原始的 null/undefined 状态,null 表示自适应模式
max_concurrent: props.editingKey.max_concurrent ?? undefined,
rate_limit: props.editingKey.rate_limit ?? undefined,
daily_limit: props.editingKey.daily_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit ?? undefined,
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
note: props.editingKey.note || '',
@@ -439,12 +441,17 @@ async function handleSave() {
saving.value = true
try {
if (props.editingKey) {
// 更新
const updateData: any = {
// 更新模式
// 注意:max_concurrent 需要显式发送 null 来切换到自适应模式
// undefined 会在 JSON 中被忽略,所以用 null 表示"清空/自适应"
const updateData: EndpointAPIKeyUpdate = {
name: form.value.name,
rate_multiplier: form.value.rate_multiplier,
internal_priority: form.value.internal_priority,
max_concurrent: form.value.max_concurrent,
// 显式使用 null 表示自适应模式,这样后端能区分"未提供"和"设置为 null"
// 注意:只有 max_concurrent 需要这种处理,因为它有"自适应模式"的概念
// 其他限制字段(rate_limit 等)不支持"清空"操作,undefined 会被 JSON 忽略即不更新
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,
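
The comments in this hunk hinge on how the default JSON serialization treats the two "empty" values: a key set to `undefined` is dropped from the request body entirely, while an explicit `null` is sent and lets the backend distinguish "leave unchanged" from "clear / switch to adaptive". A minimal TypeScript sketch with illustrative values, not taken from this repo:

```ts
// Illustrative only: JSON.stringify drops undefined-valued keys but keeps nulls.
const payload: { name: string; max_concurrent?: number | null; rate_limit?: number } = {
  name: 'key-1',
  max_concurrent: undefined, // omitted from the wire -> backend leaves the field unchanged
  rate_limit: 100,
}
console.log(JSON.stringify(payload)) // {"name":"key-1","rate_limit":100}

payload.max_concurrent = null        // explicit null -> backend clears the fixed limit
console.log(JSON.stringify(payload)) // {"name":"key-1","max_concurrent":null,"rate_limit":100}
```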
|
||||
@@ -18,7 +18,7 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 别名列表 -->
|
||||
<!-- 映射列表 -->
|
||||
<div class="space-y-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-sm font-medium">名称映射</Label>
|
||||
@@ -92,7 +92,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 别名输入框 -->
|
||||
<!-- 映射输入框 -->
|
||||
<Input
|
||||
v-model="alias.name"
|
||||
placeholder="映射名称,如 Claude-Sonnet-4.5"
|
||||
@@ -184,9 +184,9 @@ const editingPriorityIndex = ref<number | null>(null)
|
||||
// 监听 open 变化
|
||||
watch(() => props.open, (newOpen) => {
|
||||
if (newOpen && props.model) {
|
||||
// 加载现有别名配置
|
||||
if (props.model.provider_model_aliases && Array.isArray(props.model.provider_model_aliases)) {
|
||||
aliases.value = JSON.parse(JSON.stringify(props.model.provider_model_aliases))
|
||||
// 加载现有映射配置
|
||||
if (props.model.provider_model_mappings && Array.isArray(props.model.provider_model_mappings)) {
|
||||
aliases.value = JSON.parse(JSON.stringify(props.model.provider_model_mappings))
|
||||
} else {
|
||||
aliases.value = []
|
||||
}
|
||||
@@ -197,16 +197,16 @@ watch(() => props.open, (newOpen) => {
|
||||
}
|
||||
})
|
||||
|
||||
// 添加别名
|
||||
// 添加映射
|
||||
function addAlias() {
|
||||
// 新别名优先级为当前最大优先级 + 1,或者默认为 1
|
||||
// 新映射优先级为当前最大优先级 + 1,或者默认为 1
|
||||
const maxPriority = aliases.value.length > 0
|
||||
? Math.max(...aliases.value.map(a => a.priority))
|
||||
: 0
|
||||
aliases.value.push({ name: '', priority: maxPriority + 1 })
|
||||
}
|
||||
|
||||
// 移除别名
|
||||
// 移除映射
|
||||
function removeAlias(index: number) {
|
||||
aliases.value.splice(index, 1)
|
||||
}
|
||||
@@ -244,7 +244,7 @@ function handleDrop(targetIndex: number) {
|
||||
const items = [...aliases.value]
|
||||
const draggedItem = items[dragIndex]
|
||||
|
||||
// 记录每个别名的原始优先级(在修改前)
|
||||
// 记录每个映射的原始优先级(在修改前)
|
||||
const originalPriorityMap = new Map<number, number>()
|
||||
items.forEach((alias, idx) => {
|
||||
originalPriorityMap.set(idx, alias.priority)
|
||||
@@ -255,7 +255,7 @@ function handleDrop(targetIndex: number) {
|
||||
items.splice(targetIndex, 0, draggedItem)
|
||||
|
||||
// 按新顺序为每个组分配新的优先级
|
||||
// 同组的别名保持相同的优先级(被拖动的别名单独成组)
|
||||
// 同组的映射保持相同的优先级(被拖动的映射单独成组)
|
||||
const groupNewPriority = new Map<number, number>() // 原优先级 -> 新优先级
|
||||
let currentPriority = 1
|
||||
|
||||
@@ -263,12 +263,12 @@ function handleDrop(targetIndex: number) {
|
||||
const draggedOriginalPriority = originalPriorityMap.get(dragIndex)!
|
||||
|
||||
items.forEach((alias, newIdx) => {
|
||||
// 找到这个别名在原数组中的索引
|
||||
// 找到这个映射在原数组中的索引
|
||||
const originalIdx = aliases.value.findIndex(a => a === alias)
|
||||
const originalPriority = originalIdx >= 0 ? originalPriorityMap.get(originalIdx)! : alias.priority
|
||||
|
||||
if (alias === draggedItem) {
|
||||
// 被拖动的别名是独立的新组,获得当前优先级
|
||||
// 被拖动的映射是独立的新组,获得当前优先级
|
||||
alias.priority = currentPriority
|
||||
currentPriority++
|
||||
} else {
|
||||
@@ -318,11 +318,11 @@ async function handleSubmit() {
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
// 过滤掉空的别名
|
||||
// 过滤掉空的映射
|
||||
const validAliases = aliases.value.filter(a => a.name.trim())
|
||||
|
||||
await updateModel(props.providerId, props.model.id, {
|
||||
provider_model_aliases: validAliases.length > 0 ? validAliases : null
|
||||
provider_model_mappings: validAliases.length > 0 ? validAliases : null
|
||||
})
|
||||
|
||||
showSuccess('映射配置已保存')
|
||||
|
||||
@@ -0,0 +1,777 @@
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
:title="editingGroup ? '编辑模型映射' : '添加模型映射'"
|
||||
:description="editingGroup ? '修改映射配置' : '为模型添加新的名称映射'"
|
||||
:icon="Tag"
|
||||
size="4xl"
|
||||
@update:model-value="$emit('update:open', $event)"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<!-- 第一行:目标模型 | 作用域 -->
|
||||
<div class="flex gap-4">
|
||||
<!-- 目标模型 -->
|
||||
<div class="flex-1 space-y-1.5">
|
||||
<Label class="text-xs">目标模型</Label>
|
||||
<Select
|
||||
v-model:open="modelSelectOpen"
|
||||
:model-value="formData.modelId"
|
||||
:disabled="!!editingGroup"
|
||||
@update:model-value="formData.modelId = $event"
|
||||
>
|
||||
<SelectTrigger class="h-9">
|
||||
<SelectValue placeholder="请选择模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in models"
|
||||
:key="model.id"
|
||||
:value="model.id"
|
||||
>
|
||||
{{ model.global_model_display_name || model.provider_model_name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<!-- 作用域 -->
|
||||
<div class="flex-1 space-y-1.5">
|
||||
<Label class="text-xs">作用域 <span class="text-muted-foreground font-normal">(不选则适用全部)</span></Label>
|
||||
<div
|
||||
v-if="providerApiFormats.length > 0"
|
||||
class="flex flex-wrap gap-1.5 p-2 rounded-md border bg-muted/30 min-h-[36px]"
|
||||
>
|
||||
<button
|
||||
v-for="format in providerApiFormats"
|
||||
:key="format"
|
||||
type="button"
|
||||
class="px-2.5 py-0.5 rounded text-xs font-medium transition-colors"
|
||||
:class="[
|
||||
formData.apiFormats.includes(format)
|
||||
? 'bg-primary text-primary-foreground'
|
||||
: 'bg-background border border-border hover:bg-muted'
|
||||
]"
|
||||
@click="toggleApiFormat(format)"
|
||||
>
|
||||
{{ API_FORMAT_LABELS[format] || format }}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="h-9 flex items-center text-xs text-muted-foreground"
|
||||
>
|
||||
无可用格式
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 第二行:两栏布局 -->
|
||||
<div class="flex gap-4 items-stretch">
|
||||
<!-- 左侧:上游模型列表 -->
|
||||
<div class="flex-1 space-y-2">
|
||||
<div class="flex items-center justify-between gap-2">
|
||||
<span class="text-sm font-medium shrink-0">
|
||||
上游模型
|
||||
</span>
|
||||
<div class="flex-1 relative">
|
||||
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
|
||||
<Input
|
||||
v-model="upstreamModelSearch"
|
||||
placeholder="搜索模型..."
|
||||
class="pl-7 h-7 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
v-if="upstreamModelsLoaded"
|
||||
type="button"
|
||||
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
|
||||
title="刷新列表"
|
||||
:disabled="refreshingUpstreamModels"
|
||||
@click="refreshUpstreamModels"
|
||||
>
|
||||
<RefreshCw
|
||||
class="w-3.5 h-3.5"
|
||||
:class="{ 'animate-spin': refreshingUpstreamModels }"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
v-else-if="!fetchingUpstreamModels"
|
||||
type="button"
|
||||
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
|
||||
title="获取上游模型列表"
|
||||
@click="fetchUpstreamModels"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5" />
|
||||
</button>
|
||||
<Loader2
|
||||
v-else
|
||||
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
|
||||
/>
|
||||
</div>
|
||||
<div class="border rounded-lg h-80 overflow-y-auto">
|
||||
<template v-if="upstreamModelsLoaded">
|
||||
<div
|
||||
v-if="groupedAvailableUpstreamModels.length === 0"
|
||||
class="flex flex-col items-center justify-center h-full text-muted-foreground"
|
||||
>
|
||||
<Zap class="w-10 h-10 mb-2 opacity-30" />
|
||||
<p class="text-sm">
|
||||
{{ upstreamModelSearch ? '没有匹配的模型' : '所有模型已添加' }}
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="p-2 space-y-2"
|
||||
>
|
||||
<!-- 按分组显示(可折叠) -->
|
||||
<div
|
||||
v-for="group in groupedAvailableUpstreamModels"
|
||||
:key="group.api_format"
|
||||
class="border rounded-lg overflow-hidden"
|
||||
>
|
||||
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
|
||||
<button
|
||||
type="button"
|
||||
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
|
||||
@click="toggleGroupCollapse(group.api_format)"
|
||||
>
|
||||
<ChevronDown
|
||||
class="w-4 h-4 transition-transform shrink-0"
|
||||
:class="collapsedGroups.has(group.api_format) ? '-rotate-90' : ''"
|
||||
/>
|
||||
<span class="text-xs font-medium">
|
||||
{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}
|
||||
</span>
|
||||
<span class="text-xs text-muted-foreground">
|
||||
({{ group.models.length }})
|
||||
</span>
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
v-show="!collapsedGroups.has(group.api_format)"
|
||||
class="p-2 space-y-1 border-t"
|
||||
>
|
||||
<div
|
||||
v-for="model in group.models"
|
||||
:key="model.id"
|
||||
class="flex items-center gap-2 p-2 rounded-lg border transition-colors hover:bg-muted/30"
|
||||
:title="model.id"
|
||||
>
|
||||
<div class="flex-1 min-w-0">
|
||||
<p class="font-medium text-sm truncate">
|
||||
{{ model.id }}
|
||||
</p>
|
||||
<p class="text-xs text-muted-foreground truncate font-mono">
|
||||
{{ model.owned_by || model.id }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
class="p-1 hover:bg-primary/10 rounded transition-colors shrink-0"
|
||||
title="添加到映射"
|
||||
@click="addUpstreamModel(model.id)"
|
||||
>
|
||||
<ChevronRight class="w-4 h-4 text-muted-foreground hover:text-primary" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- 未加载状态 -->
|
||||
<div
|
||||
v-else
|
||||
class="flex flex-col items-center justify-center h-full text-muted-foreground"
|
||||
>
|
||||
<Zap class="w-10 h-10 mb-2 opacity-30" />
|
||||
<p class="text-sm">
|
||||
点击右上角按钮
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
从上游获取可用模型
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 右侧:映射名称列表 -->
|
||||
<div class="flex-1 space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<p class="text-sm font-medium">
|
||||
映射名称
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
class="p-1.5 hover:bg-muted rounded-md transition-colors"
|
||||
title="手动添加"
|
||||
@click="addAliasItem"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</button>
|
||||
</div>
|
||||
<div class="border rounded-lg h-80 overflow-y-auto">
|
||||
<div
|
||||
v-if="formData.aliases.length === 0"
|
||||
class="flex flex-col items-center justify-center h-full text-muted-foreground"
|
||||
>
|
||||
<Tag class="w-10 h-10 mb-2 opacity-30" />
|
||||
<p class="text-sm">
|
||||
从左侧选择模型
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
或点击上方"手动添加"
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="p-2 space-y-1"
|
||||
>
|
||||
<div
|
||||
v-for="(alias, index) in formData.aliases"
|
||||
:key="`alias-${index}`"
|
||||
class="group flex items-center gap-2 p-2 rounded-lg border transition-colors hover:bg-muted/30"
|
||||
:class="[
|
||||
draggedIndex === index ? 'bg-primary/5' : '',
|
||||
dragOverIndex === index ? 'bg-primary/10 border-primary' : ''
|
||||
]"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart(index, $event)"
|
||||
@dragend="handleDragEnd"
|
||||
@dragover.prevent="handleDragOver(index)"
|
||||
@dragleave="handleDragLeave"
|
||||
@drop="handleDrop(index)"
|
||||
>
|
||||
<!-- 删除按钮 -->
|
||||
<button
|
||||
type="button"
|
||||
class="p-1 hover:bg-destructive/10 rounded transition-colors shrink-0"
|
||||
title="移除"
|
||||
@click="removeAliasItem(index)"
|
||||
>
|
||||
<ChevronLeft class="w-4 h-4 text-muted-foreground hover:text-destructive" />
|
||||
</button>
|
||||
|
||||
<!-- 优先级 -->
|
||||
<div class="shrink-0">
|
||||
<input
|
||||
v-if="editingPriorityIndex === index"
|
||||
type="number"
|
||||
min="1"
|
||||
:value="alias.priority"
|
||||
class="w-7 h-6 rounded bg-background border border-primary text-xs text-center focus:outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||
autofocus
|
||||
@blur="finishEditPriority(index, $event)"
|
||||
@keydown.enter="($event.target as HTMLInputElement).blur()"
|
||||
@keydown.escape="cancelEditPriority"
|
||||
>
|
||||
<div
|
||||
v-else
|
||||
class="w-6 h-6 rounded bg-muted/50 flex items-center justify-center text-xs text-muted-foreground cursor-pointer hover:bg-primary/10 hover:text-primary"
|
||||
title="点击编辑优先级"
|
||||
@click.stop="startEditPriority(index)"
|
||||
>
|
||||
{{ alias.priority }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 名称显示/编辑 -->
|
||||
<div class="flex-1 min-w-0">
|
||||
<Input
|
||||
v-if="alias.isEditing"
|
||||
v-model="alias.name"
|
||||
placeholder="输入映射名称"
|
||||
class="h-7 text-xs"
|
||||
autofocus
|
||||
@blur="alias.isEditing = false"
|
||||
@keydown.enter="alias.isEditing = false"
|
||||
/>
|
||||
<p
|
||||
v-else
|
||||
class="font-medium text-sm truncate cursor-pointer hover:text-primary"
|
||||
title="点击编辑"
|
||||
@click="alias.isEditing = true"
|
||||
>
|
||||
{{ alias.name || '点击输入名称' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 拖拽手柄 -->
|
||||
<div class="cursor-grab active:cursor-grabbing text-muted-foreground/30 group-hover:text-muted-foreground shrink-0">
|
||||
<GripVertical class="w-4 h-4" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- 拖拽提示 -->
|
||||
<div
|
||||
v-if="formData.aliases.length > 1"
|
||||
class="px-3 py-1.5 bg-muted/30 border-t text-xs text-muted-foreground text-center"
|
||||
>
|
||||
拖拽调整优先级顺序
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="$emit('update:open', false)"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting || !formData.modelId || formData.aliases.length === 0 || !hasValidAliases"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
{{ editingGroup ? '保存' : '添加' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import { Tag, Loader2, GripVertical, Zap, Search, RefreshCw, ChevronDown, ChevronRight, ChevronLeft, Plus } from 'lucide-vue-next'
|
||||
import {
|
||||
Button,
|
||||
Input,
|
||||
Label,
|
||||
Dialog,
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import {
|
||||
API_FORMAT_LABELS,
|
||||
type Model,
|
||||
type ProviderModelAlias
|
||||
} from '@/api/endpoints'
|
||||
import { updateModel } from '@/api/endpoints/models'
|
||||
import { useUpstreamModelsCache, type UpstreamModel } from '../composables/useUpstreamModelsCache'
|
||||
|
||||
interface FormAlias {
|
||||
name: string
|
||||
priority: number
|
||||
isEditing?: boolean
|
||||
}
|
||||
|
||||
export interface AliasGroup {
|
||||
model: Model
|
||||
apiFormatsKey: string
|
||||
apiFormats: string[]
|
||||
aliases: ProviderModelAlias[]
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
open: boolean
|
||||
providerId: string
|
||||
providerApiFormats: string[]
|
||||
models: Model[]
|
||||
editingGroup?: AliasGroup | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:open': [value: boolean]
|
||||
'saved': []
|
||||
}>()
|
||||
|
||||
const { error: showError, success: showSuccess } = useToast()
|
||||
const { fetchModels: fetchCachedModels, clearCache, getCachedModels } = useUpstreamModelsCache()
|
||||
|
||||
// 状态
|
||||
const submitting = ref(false)
|
||||
const modelSelectOpen = ref(false)
|
||||
|
||||
// 拖拽状态
|
||||
const draggedIndex = ref<number | null>(null)
|
||||
const dragOverIndex = ref<number | null>(null)
|
||||
|
||||
// 优先级编辑状态
|
||||
const editingPriorityIndex = ref<number | null>(null)
|
||||
|
||||
// 快速添加(上游模型)状态
|
||||
const fetchingUpstreamModels = ref(false)
|
||||
const refreshingUpstreamModels = ref(false)
|
||||
const upstreamModelsLoaded = ref(false)
|
||||
const upstreamModels = ref<UpstreamModel[]>([])
|
||||
const upstreamModelSearch = ref('')
|
||||
|
||||
// 分组折叠状态
|
||||
const collapsedGroups = ref<Set<string>>(new Set())
|
||||
|
||||
// 表单数据
|
||||
const formData = ref<{
|
||||
modelId: string
|
||||
apiFormats: string[]
|
||||
aliases: FormAlias[]
|
||||
}>({
|
||||
modelId: '',
|
||||
apiFormats: [],
|
||||
aliases: []
|
||||
})
|
||||
|
||||
// 检查是否有有效的映射
|
||||
const hasValidAliases = computed(() => {
|
||||
return formData.value.aliases.some(a => a.name.trim())
|
||||
})
|
||||
|
||||
// 过滤和排序后的上游模型列表
|
||||
const filteredUpstreamModels = computed(() => {
|
||||
const searchText = upstreamModelSearch.value.toLowerCase().trim()
|
||||
let result = [...upstreamModels.value]
|
||||
|
||||
result.sort((a, b) => a.id.localeCompare(b.id))
|
||||
|
||||
if (searchText) {
|
||||
const keywords = searchText.split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.id} ${m.owned_by || ''} ${m.api_format || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// 按 API 格式分组的上游模型列表
|
||||
interface UpstreamModelGroup {
|
||||
api_format: string
|
||||
models: Array<{ id: string; owned_by?: string; api_format?: string }>
|
||||
}
|
||||
|
||||
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
||||
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
||||
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
||||
|
||||
const groups = new Map<string, UpstreamModelGroup>()
|
||||
|
||||
for (const model of availableModels) {
|
||||
const format = model.api_format || 'UNKNOWN'
|
||||
if (!groups.has(format)) {
|
||||
groups.set(format, { api_format: format, models: [] })
|
||||
}
|
||||
groups.get(format)!.models.push(model)
|
||||
}
|
||||
|
||||
const order = Object.keys(API_FORMAT_LABELS)
|
||||
return Array.from(groups.values()).sort((a, b) => {
|
||||
const aIndex = order.indexOf(a.api_format)
|
||||
const bIndex = order.indexOf(b.api_format)
|
||||
if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
|
||||
if (aIndex === -1) return 1
|
||||
if (bIndex === -1) return -1
|
||||
return aIndex - bIndex
|
||||
})
|
||||
})
|
||||
|
||||
// 监听打开状态
|
||||
watch(() => props.open, (isOpen) => {
|
||||
if (isOpen) {
|
||||
initForm()
|
||||
}
|
||||
})
|
||||
|
||||
// 初始化表单
|
||||
function initForm() {
|
||||
if (props.editingGroup) {
|
||||
formData.value = {
|
||||
modelId: props.editingGroup.model.id,
|
||||
apiFormats: [...props.editingGroup.apiFormats],
|
||||
aliases: props.editingGroup.aliases.map(a => ({ name: a.name, priority: a.priority }))
|
||||
}
|
||||
} else {
|
||||
formData.value = {
|
||||
modelId: '',
|
||||
apiFormats: [],
|
||||
aliases: []
|
||||
}
|
||||
}
|
||||
// 重置状态
|
||||
editingPriorityIndex.value = null
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
upstreamModelSearch.value = ''
|
||||
collapsedGroups.value = new Set()
|
||||
|
||||
// 检查缓存,如果有缓存数据则直接使用
|
||||
const cachedModels = getCachedModels(props.providerId)
|
||||
if (cachedModels) {
|
||||
upstreamModels.value = cachedModels
|
||||
upstreamModelsLoaded.value = true
|
||||
// 默认折叠所有分组
|
||||
for (const model of cachedModels) {
|
||||
if (model.api_format) {
|
||||
collapsedGroups.value.add(model.api_format)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
upstreamModelsLoaded.value = false
|
||||
upstreamModels.value = []
|
||||
}
|
||||
}
|
||||
|
||||
// 切换 API 格式
|
||||
function toggleApiFormat(format: string) {
|
||||
const index = formData.value.apiFormats.indexOf(format)
|
||||
if (index >= 0) {
|
||||
formData.value.apiFormats.splice(index, 1)
|
||||
} else {
|
||||
formData.value.apiFormats.push(format)
|
||||
}
|
||||
}
|
||||
|
||||
// 切换分组折叠状态
|
||||
function toggleGroupCollapse(apiFormat: string) {
|
||||
if (collapsedGroups.value.has(apiFormat)) {
|
||||
collapsedGroups.value.delete(apiFormat)
|
||||
} else {
|
||||
collapsedGroups.value.add(apiFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加映射项
|
||||
function addAliasItem() {
|
||||
const maxPriority = formData.value.aliases.length > 0
|
||||
? Math.max(...formData.value.aliases.map(a => a.priority))
|
||||
: 0
|
||||
formData.value.aliases.push({ name: '', priority: maxPriority + 1, isEditing: true })
|
||||
}
|
||||
|
||||
// 删除映射项
|
||||
function removeAliasItem(index: number) {
|
||||
formData.value.aliases.splice(index, 1)
|
||||
}
|
||||
|
||||
// ===== 拖拽排序 =====
|
||||
function handleDragStart(index: number, event: DragEvent) {
|
||||
draggedIndex.value = index
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.effectAllowed = 'move'
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragEnd() {
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
}
|
||||
|
||||
function handleDragOver(index: number) {
|
||||
if (draggedIndex.value !== null && draggedIndex.value !== index) {
|
||||
dragOverIndex.value = index
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragLeave() {
|
||||
dragOverIndex.value = null
|
||||
}
|
||||
|
||||
function handleDrop(targetIndex: number) {
|
||||
const dragIndex = draggedIndex.value
|
||||
if (dragIndex === null || dragIndex === targetIndex) {
|
||||
dragOverIndex.value = null
|
||||
return
|
||||
}
|
||||
|
||||
const items = [...formData.value.aliases]
|
||||
const draggedItem = items[dragIndex]
|
||||
|
||||
const originalPriorityMap = new Map<number, number>()
|
||||
items.forEach((alias, idx) => {
|
||||
originalPriorityMap.set(idx, alias.priority)
|
||||
})
|
||||
|
||||
items.splice(dragIndex, 1)
|
||||
items.splice(targetIndex, 0, draggedItem)
|
||||
|
||||
const groupNewPriority = new Map<number, number>()
|
||||
let currentPriority = 1
|
||||
|
||||
items.forEach((alias) => {
|
||||
const originalIdx = formData.value.aliases.findIndex(a => a === alias)
|
||||
const originalPriority = originalIdx >= 0 ? originalPriorityMap.get(originalIdx)! : alias.priority
|
||||
|
||||
if (alias === draggedItem) {
|
||||
alias.priority = currentPriority
|
||||
currentPriority++
|
||||
} else {
|
||||
if (groupNewPriority.has(originalPriority)) {
|
||||
alias.priority = groupNewPriority.get(originalPriority)!
|
||||
} else {
|
||||
groupNewPriority.set(originalPriority, currentPriority)
|
||||
alias.priority = currentPriority
|
||||
currentPriority++
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
formData.value.aliases = items
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
}
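
The drop handler above preserves priority ties: aliases that shared a priority before the drag keep sharing one afterwards, while the dragged alias always becomes its own group. A self-contained sketch of that rule (names and values are illustrative, not from the repo):

```ts
interface SketchAlias { name: string; priority: number }

function reprioritise(items: SketchAlias[], dragIndex: number, targetIndex: number): SketchAlias[] {
  const copies = items.map(a => ({ ...a }))
  const dragged = copies[dragIndex]
  const reordered = [...copies]
  reordered.splice(dragIndex, 1)
  reordered.splice(targetIndex, 0, dragged)

  const groupNewPriority = new Map<number, number>() // old priority -> new priority
  let next = 1
  for (const alias of reordered) {
    if (alias === dragged) {
      alias.priority = next++ // the dragged alias always forms its own group
    } else if (groupNewPriority.has(alias.priority)) {
      alias.priority = groupNewPriority.get(alias.priority)! // keep sharing with its old group
    } else {
      groupNewPriority.set(alias.priority, next)
      alias.priority = next++
    }
  }
  return reordered
}

// Before: A and B share priority 1, C has priority 2. Drag C to the front.
console.log(reprioritise(
  [{ name: 'A', priority: 1 }, { name: 'B', priority: 1 }, { name: 'C', priority: 2 }],
  2, 0,
))
// => [ { name: 'C', priority: 1 }, { name: 'A', priority: 2 }, { name: 'B', priority: 2 } ]
```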
|
||||
|
||||
// ===== 优先级编辑 =====
|
||||
function startEditPriority(index: number) {
|
||||
editingPriorityIndex.value = index
|
||||
}
|
||||
|
||||
function finishEditPriority(index: number, event: FocusEvent) {
|
||||
const input = event.target as HTMLInputElement
|
||||
const newPriority = parseInt(input.value) || 1
|
||||
formData.value.aliases[index].priority = Math.max(1, newPriority)
|
||||
editingPriorityIndex.value = null
|
||||
}
|
||||
|
||||
function cancelEditPriority() {
|
||||
editingPriorityIndex.value = null
|
||||
}
|
||||
|
||||
// ===== 快速添加(上游模型)=====
|
||||
async function fetchUpstreamModels() {
|
||||
if (!props.providerId) return
|
||||
|
||||
upstreamModelSearch.value = ''
|
||||
fetchingUpstreamModels.value = true
|
||||
|
||||
try {
|
||||
const result = await fetchCachedModels(props.providerId)
|
||||
if (result) {
|
||||
if (result.error) {
|
||||
showError(result.error, '错误')
|
||||
} else {
|
||||
upstreamModels.value = result.models
|
||||
upstreamModelsLoaded.value = true
|
||||
// 默认折叠所有分组
|
||||
for (const model of result.models) {
|
||||
if (model.api_format) {
|
||||
collapsedGroups.value.add(model.api_format)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
fetchingUpstreamModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function addUpstreamModel(modelId: string) {
|
||||
if (formData.value.aliases.some(a => a.name === modelId)) {
|
||||
return
|
||||
}
|
||||
|
||||
const maxPriority = formData.value.aliases.length > 0
|
||||
? Math.max(...formData.value.aliases.map(a => a.priority))
|
||||
: 0
|
||||
|
||||
formData.value.aliases.push({ name: modelId, priority: maxPriority + 1 })
|
||||
}
|
||||
|
||||
async function refreshUpstreamModels() {
|
||||
if (!props.providerId || refreshingUpstreamModels.value) return
|
||||
|
||||
refreshingUpstreamModels.value = true
|
||||
clearCache(props.providerId)
|
||||
|
||||
try {
|
||||
const result = await fetchCachedModels(props.providerId, true)
|
||||
if (result) {
|
||||
if (result.error) {
|
||||
showError(result.error, '错误')
|
||||
} else {
|
||||
upstreamModels.value = result.models
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
refreshingUpstreamModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 生成作用域唯一键
function getApiFormatsKey(formats: string[] | undefined): string {
if (!formats || formats.length === 0) return ''
return [...formats].sort().join(',')
}
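
A quick usage note: the key is order-insensitive, so two mapping entries scoped to the same set of formats always land in the same group (the format names below are placeholders):

```ts
getApiFormatsKey(['openai', 'anthropic'])  // 'anthropic,openai'
getApiFormatsKey(['anthropic', 'openai'])  // 'anthropic,openai', same group key
getApiFormatsKey(undefined)                // '', treated as "applies to all formats"
```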
|
||||
|
||||
// 提交表单
|
||||
async function handleSubmit() {
|
||||
if (submitting.value) return
|
||||
if (!formData.value.modelId || formData.value.aliases.length === 0) return
|
||||
|
||||
const validAliases = formData.value.aliases.filter(a => a.name.trim())
|
||||
if (validAliases.length === 0) {
|
||||
showError('请至少添加一个有效的映射名称', '错误')
|
||||
return
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
const targetModel = props.models.find(m => m.id === formData.value.modelId)
|
||||
if (!targetModel) {
|
||||
showError('模型不存在', '错误')
|
||||
return
|
||||
}
|
||||
|
||||
const currentAliases = targetModel.provider_model_mappings || []
|
||||
let newAliases: ProviderModelAlias[]
|
||||
|
||||
const buildAlias = (a: FormAlias): ProviderModelAlias => ({
|
||||
name: a.name.trim(),
|
||||
priority: a.priority,
|
||||
...(formData.value.apiFormats.length > 0 ? { api_formats: formData.value.apiFormats } : {})
|
||||
})
|
||||
|
||||
if (props.editingGroup) {
|
||||
const oldApiFormatsKey = props.editingGroup.apiFormatsKey
|
||||
const oldAliasNames = new Set(props.editingGroup.aliases.map(a => a.name))
|
||||
|
||||
const filteredAliases = currentAliases.filter((a: ProviderModelAlias) => {
|
||||
const currentKey = getApiFormatsKey(a.api_formats)
|
||||
return !(currentKey === oldApiFormatsKey && oldAliasNames.has(a.name))
|
||||
})
|
||||
|
||||
const existingNames = new Set(filteredAliases.map((a: ProviderModelAlias) => a.name))
|
||||
const duplicates = validAliases.filter(a => existingNames.has(a.name.trim()))
|
||||
if (duplicates.length > 0) {
|
||||
showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误')
|
||||
return
|
||||
}
|
||||
|
||||
newAliases = [
|
||||
...filteredAliases,
|
||||
...validAliases.map(buildAlias)
|
||||
]
|
||||
} else {
|
||||
const existingNames = new Set(currentAliases.map((a: ProviderModelAlias) => a.name))
|
||||
const duplicates = validAliases.filter(a => existingNames.has(a.name.trim()))
|
||||
if (duplicates.length > 0) {
|
||||
showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误')
|
||||
return
|
||||
}
|
||||
newAliases = [
|
||||
...currentAliases,
|
||||
...validAliases.map(buildAlias)
|
||||
]
|
||||
}
|
||||
|
||||
await updateModel(props.providerId, targetModel.id, {
|
||||
provider_model_mappings: newAliases
|
||||
})
|
||||
|
||||
showSuccess(props.editingGroup ? '映射组已更新' : '映射已添加')
|
||||
emit('update:open', false)
|
||||
emit('saved')
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || '操作失败', '错误')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -483,9 +483,9 @@
<span
v-if="key.max_concurrent || key.is_adaptive"
class="text-muted-foreground"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'})` : '固定并发限制'"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'})` : `固定并发限制: ${key.max_concurrent}`"
>
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.learned_max_concurrent || key.max_concurrent || 3 }}
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.is_adaptive ? (key.learned_max_concurrent ?? '学习中') : key.max_concurrent }}
</span>
</div>
</div>
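
The display fix above replaces the old `||` fallback chain (which could show a stale learned value for a fixed key, or a hard-coded 3) with an explicit `is_adaptive` branch using `??`, so only `null`/`undefined` fall through to the placeholder. The operator difference in two lines (illustrative value):

```ts
const learned: number | null = 0
console.log(learned || '学习中') // '学习中'  (|| also falls through on 0)
console.log(learned ?? '学习中') // 0         (?? only falls through on null/undefined)
```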
@@ -655,6 +655,7 @@ import {
|
||||
GripVertical,
|
||||
Copy
|
||||
} from 'lucide-vue-next'
|
||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
@@ -1296,6 +1297,16 @@ async function loadEndpoints() {
|
||||
showError(err.response?.data?.detail || '加载端点失败', '错误')
|
||||
}
|
||||
}
|
||||
|
||||
// 添加 ESC 键监听
|
||||
useEscapeKey(() => {
|
||||
if (props.open) {
|
||||
handleClose()
|
||||
}
|
||||
}, {
|
||||
disableOnInput: true,
|
||||
once: false
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -101,24 +101,24 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 展开的别名列表 -->
|
||||
<!-- 展开的映射列表 -->
|
||||
<div
|
||||
v-show="expandedAliasGroups.has(`${group.model.id}-${group.apiFormatsKey}`)"
|
||||
class="bg-muted/30 border-t border-border/30"
|
||||
>
|
||||
<div class="px-4 py-2 space-y-1">
|
||||
<div
|
||||
v-for="alias in group.aliases"
|
||||
:key="alias.name"
|
||||
v-for="mapping in group.aliases"
|
||||
:key="mapping.name"
|
||||
class="flex items-center gap-2 py-1"
|
||||
>
|
||||
<!-- 优先级标签 -->
|
||||
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
||||
{{ alias.priority }}
|
||||
{{ mapping.priority }}
|
||||
</span>
|
||||
<!-- 别名名称 -->
|
||||
<!-- 映射名称 -->
|
||||
<span class="font-mono text-sm truncate">
|
||||
{{ alias.name }}
|
||||
{{ mapping.name }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
@@ -142,330 +142,14 @@
|
||||
</Card>
|
||||
|
||||
<!-- 添加/编辑映射对话框 -->
|
||||
<Dialog
|
||||
v-model="dialogOpen"
|
||||
:title="editingItem ? '编辑模型映射' : '添加模型映射'"
|
||||
:description="editingItem ? '修改映射配置' : '为模型添加新的名称映射'"
|
||||
:icon="Tag"
|
||||
size="xl"
|
||||
>
|
||||
<div class="space-y-3">
|
||||
<!-- 第一行:目标模型 | 作用域 -->
|
||||
<div class="flex gap-4">
|
||||
<!-- 目标模型 -->
|
||||
<div class="flex-1 space-y-1.5">
|
||||
<Label class="text-xs">目标模型</Label>
|
||||
<Select
|
||||
v-model:open="modelSelectOpen"
|
||||
:model-value="formData.modelId"
|
||||
:disabled="!!editingItem"
|
||||
@update:model-value="formData.modelId = $event"
|
||||
>
|
||||
<SelectTrigger class="h-9">
|
||||
<SelectValue placeholder="请选择模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in models"
|
||||
:key="model.id"
|
||||
:value="model.id"
|
||||
>
|
||||
{{ model.global_model_display_name || model.provider_model_name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<!-- 作用域 -->
|
||||
<div class="flex-1 space-y-1.5">
|
||||
<Label class="text-xs">作用域 <span class="text-muted-foreground font-normal">(不选则适用全部)</span></Label>
|
||||
<div
|
||||
v-if="providerApiFormats.length > 0"
|
||||
class="flex flex-wrap gap-1.5 p-2 rounded-md border bg-muted/30 min-h-[36px]"
|
||||
>
|
||||
<button
|
||||
v-for="format in providerApiFormats"
|
||||
:key="format"
|
||||
type="button"
|
||||
class="px-2.5 py-0.5 rounded text-xs font-medium transition-colors"
|
||||
:class="[
|
||||
formData.apiFormats.includes(format)
|
||||
? 'bg-primary text-primary-foreground'
|
||||
: 'bg-background border border-border hover:bg-muted'
|
||||
]"
|
||||
@click="toggleApiFormat(format)"
|
||||
>
|
||||
{{ API_FORMAT_LABELS[format] || format }}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="h-9 flex items-center text-xs text-muted-foreground"
|
||||
>
|
||||
无可用格式
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 第二行:上游模型 | 映射名称 -->
|
||||
<div class="flex gap-4 h-[340px]">
|
||||
<!-- 左侧:上游模型列表 -->
|
||||
<div class="flex-1 flex flex-col border rounded-lg overflow-hidden">
|
||||
<!-- 左侧头部:标题 + 搜索 + 操作按钮 -->
|
||||
<div class="px-3 py-2 bg-muted/50 border-b flex items-center gap-2 shrink-0">
|
||||
<span class="text-xs font-medium shrink-0">上游模型</span>
|
||||
<!-- 搜索框 -->
|
||||
<div class="flex-1 relative">
|
||||
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
|
||||
<Input
|
||||
v-model="upstreamModelSearch"
|
||||
placeholder="搜索模型..."
|
||||
class="pl-7 h-7 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<!-- 操作按钮 -->
|
||||
<button
|
||||
v-if="upstreamModelsLoaded"
|
||||
class="p-1.5 rounded hover:bg-muted transition-colors shrink-0"
|
||||
title="刷新列表"
|
||||
:disabled="refreshingUpstreamModels"
|
||||
@click="refreshUpstreamModels"
|
||||
>
|
||||
<RefreshCw
|
||||
class="w-3.5 h-3.5"
|
||||
:class="{ 'animate-spin': refreshingUpstreamModels }"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
v-else-if="!fetchingUpstreamModels"
|
||||
class="p-1.5 rounded hover:bg-muted transition-colors shrink-0"
|
||||
title="获取上游模型列表"
|
||||
@click="fetchUpstreamModels"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5" />
|
||||
</button>
|
||||
<Loader2
|
||||
v-else
|
||||
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- 模型列表 -->
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<template v-if="upstreamModelsLoaded">
|
||||
<!-- 按分组显示(可折叠) -->
|
||||
<div
|
||||
v-for="group in groupedAvailableUpstreamModels"
|
||||
:key="group.api_format"
|
||||
>
|
||||
<div
|
||||
class="sticky top-0 z-10 px-3 py-1.5 bg-muted/80 backdrop-blur-sm border-b flex items-center justify-between cursor-pointer hover:bg-muted/90 transition-colors"
|
||||
@click="toggleGroupCollapse(group.api_format)"
|
||||
>
|
||||
<div class="flex items-center gap-1.5">
|
||||
<ChevronRight
|
||||
class="w-3.5 h-3.5 transition-transform"
|
||||
:class="{ 'rotate-90': !collapsedGroups.has(group.api_format) }"
|
||||
/>
|
||||
<span class="text-xs font-medium">{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}</span>
|
||||
<span class="text-xs text-muted-foreground">({{ group.models.length }})</span>
|
||||
</div>
|
||||
<button
|
||||
class="text-xs text-primary hover:underline"
|
||||
@click.stop="addAllFromGroup(group.api_format)"
|
||||
>
|
||||
全部添加
|
||||
</button>
|
||||
</div>
|
||||
<div v-show="!collapsedGroups.has(group.api_format)">
|
||||
<div
|
||||
v-for="model in group.models"
|
||||
:key="model.id"
|
||||
class="group flex items-center gap-2 px-3 py-1.5 hover:bg-muted/50 cursor-pointer transition-colors"
|
||||
:title="model.id"
|
||||
@click="addUpstreamModel(model.id)"
|
||||
>
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="font-mono text-xs truncate">
|
||||
{{ model.id }}
|
||||
</div>
|
||||
<div
|
||||
v-if="model.owned_by"
|
||||
class="text-xs text-muted-foreground truncate"
|
||||
>
|
||||
{{ model.owned_by }}
|
||||
</div>
|
||||
</div>
|
||||
<Plus class="w-3.5 h-3.5 text-muted-foreground/50 group-hover:text-primary transition-colors shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<div
|
||||
v-if="groupedAvailableUpstreamModels.length === 0"
|
||||
class="flex items-center justify-center h-full text-muted-foreground text-xs p-4"
|
||||
>
|
||||
{{ upstreamModelSearch ? '没有匹配的模型' : '所有模型已添加' }}
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- 未加载状态 -->
|
||||
<div
|
||||
v-else
|
||||
class="flex flex-col items-center justify-center h-full text-muted-foreground p-4"
|
||||
>
|
||||
<Zap class="w-8 h-8 mb-2 opacity-30" />
|
||||
<p class="text-xs text-center">
|
||||
点击右上角按钮<br>从上游获取可用模型
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 右侧:映射模型(编辑模式下全宽) -->
|
||||
<div class="flex-1 flex flex-col border rounded-lg overflow-hidden">
|
||||
<div class="px-3 py-2 bg-primary/5 border-b flex items-center justify-between shrink-0">
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs font-medium">映射名称</span>
|
||||
<Badge
|
||||
v-if="formData.aliases.length > 0"
|
||||
variant="secondary"
|
||||
class="text-xs h-5"
|
||||
>
|
||||
{{ formData.aliases.length }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-1">
|
||||
<button
|
||||
v-if="formData.aliases.length > 0"
|
||||
class="p-1.5 rounded hover:bg-muted text-muted-foreground hover:text-destructive transition-colors"
|
||||
title="清空"
|
||||
@click="formData.aliases = []"
|
||||
>
|
||||
<Eraser class="w-3.5 h-3.5" />
|
||||
</button>
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-muted transition-colors"
|
||||
title="手动添加"
|
||||
@click="addAliasItem"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 已选列表 -->
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<div
|
||||
v-if="formData.aliases.length > 0"
|
||||
class="divide-y divide-border/30"
|
||||
>
|
||||
<div
|
||||
v-for="(alias, index) in formData.aliases"
|
||||
:key="`alias-${index}`"
|
||||
class="group flex items-center gap-1.5 px-2 py-1.5 hover:bg-muted/30 transition-colors"
|
||||
:class="[
|
||||
draggedIndex === index ? 'bg-primary/5' : '',
|
||||
dragOverIndex === index ? 'bg-primary/10' : ''
|
||||
]"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart(index, $event)"
|
||||
@dragend="handleDragEnd"
|
||||
@dragover.prevent="handleDragOver(index)"
|
||||
@dragleave="handleDragLeave"
|
||||
@drop="handleDrop(index)"
|
||||
>
|
||||
<!-- 拖拽手柄 -->
|
||||
<div class="cursor-grab active:cursor-grabbing text-muted-foreground/30 group-hover:text-muted-foreground shrink-0">
|
||||
<GripVertical class="w-3 h-3" />
|
||||
</div>
|
||||
|
||||
<!-- 优先级 -->
|
||||
<div class="shrink-0">
|
||||
<input
|
||||
v-if="editingPriorityIndex === index"
|
||||
type="number"
|
||||
min="1"
|
||||
:value="alias.priority"
|
||||
class="w-6 h-5 rounded bg-background border border-primary text-xs text-center focus:outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||
autofocus
|
||||
@blur="finishEditPriority(index, $event)"
|
||||
@keydown.enter="($event.target as HTMLInputElement).blur()"
|
||||
@keydown.escape="cancelEditPriority"
|
||||
>
|
||||
<div
|
||||
v-else
|
||||
class="w-5 h-5 rounded bg-muted/50 flex items-center justify-center text-xs text-muted-foreground cursor-pointer hover:bg-primary/10 hover:text-primary"
|
||||
title="点击编辑优先级"
|
||||
@click.stop="startEditPriority(index)"
|
||||
>
|
||||
{{ alias.priority }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 名称输入 -->
|
||||
<Input
|
||||
v-model="alias.name"
|
||||
placeholder="映射名称"
|
||||
class="flex-1 h-6 text-xs px-2"
|
||||
/>
|
||||
|
||||
<!-- 删除按钮 -->
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="shrink-0 text-muted-foreground hover:text-destructive h-5 w-5 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
@click="removeAliasItem(index)"
|
||||
>
|
||||
<X class="w-3 h-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<div
|
||||
v-else
|
||||
class="flex flex-col items-center justify-center h-full text-muted-foreground p-4"
|
||||
>
|
||||
<Tag class="w-8 h-8 mb-2 opacity-30" />
|
||||
<p class="text-xs text-center">
|
||||
从左侧选择模型<br>或手动添加映射
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 拖拽提示 -->
|
||||
<div
|
||||
v-if="formData.aliases.length > 1"
|
||||
class="px-3 py-1.5 bg-muted/30 border-t text-xs text-muted-foreground text-center shrink-0"
|
||||
>
|
||||
拖拽调整优先级顺序
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="dialogOpen = false"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting || !formData.modelId || formData.aliases.length === 0 || !hasValidAliases"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
{{ editingItem ? '保存' : '添加' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
<ModelMappingDialog
|
||||
v-model:open="dialogOpen"
|
||||
:provider-id="provider.id"
|
||||
:provider-api-formats="providerApiFormats"
|
||||
:models="models"
|
||||
:editing-group="editingGroup"
|
||||
@saved="onDialogSaved"
|
||||
/>
|
||||
|
||||
<!-- 删除确认对话框 -->
|
||||
<AlertDialog
|
||||
@@ -482,21 +166,10 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import { Tag, Plus, Edit, Trash2, Loader2, GripVertical, X, Zap, Search, RefreshCw, ChevronRight, Eraser } from 'lucide-vue-next'
|
||||
import {
|
||||
Card,
|
||||
Button,
|
||||
Badge,
|
||||
Input,
|
||||
Label,
|
||||
Dialog,
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui'
|
||||
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next'
|
||||
import { Card, Button, Badge } from '@/components/ui'
|
||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import {
|
||||
getProviderModels,
|
||||
@@ -505,17 +178,6 @@ import {
|
||||
type ProviderModelAlias
|
||||
} from '@/api/endpoints'
|
||||
import { updateModel } from '@/api/endpoints/models'
|
||||
import { adminApi } from '@/api/admin'
|
||||
|
||||
interface AliasItem {
|
||||
model: Model
|
||||
alias: ProviderModelAlias
|
||||
}
|
||||
|
||||
interface FormAlias {
|
||||
name: string
|
||||
priority: number
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
provider: any
|
||||
@@ -532,131 +194,22 @@ const loading = ref(false)
|
||||
const models = ref<Model[]>([])
|
||||
const dialogOpen = ref(false)
|
||||
const deleteConfirmOpen = ref(false)
|
||||
const submitting = ref(false)
|
||||
const editingItem = ref<AliasItem | null>(null)
|
||||
const editingGroup = ref<AliasGroup | null>(null)
|
||||
const deletingGroup = ref<AliasGroup | null>(null)
|
||||
const modelSelectOpen = ref(false)
|
||||
|
||||
// 拖拽状态
|
||||
const draggedIndex = ref<number | null>(null)
|
||||
const dragOverIndex = ref<number | null>(null)
|
||||
|
||||
// 优先级编辑状态
|
||||
const editingPriorityIndex = ref<number | null>(null)
|
||||
|
||||
// 快速添加(上游模型)状态
|
||||
const fetchingUpstreamModels = ref(false)
|
||||
const refreshingUpstreamModels = ref(false)
|
||||
const upstreamModelsLoaded = ref(false)
|
||||
const upstreamModels = ref<Array<{ id: string; owned_by?: string; api_format?: string }>>([])
|
||||
const upstreamModelSearch = ref('')
|
||||
|
||||
// 分组折叠状态(上游模型列表)
|
||||
const collapsedGroups = ref<Set<string>>(new Set())
|
||||
|
||||
// 列表展开状态(映射组列表)
|
||||
// 列表展开状态
|
||||
const expandedAliasGroups = ref<Set<string>>(new Set())
|
||||
|
||||
// 上游模型缓存(按 Provider ID)
|
||||
const upstreamModelsCache = ref<Map<string, {
|
||||
models: Array<{ id: string; owned_by?: string; api_format?: string }>
|
||||
timestamp: number
|
||||
}>>(new Map())
|
||||
const CACHE_TTL = 5 * 60 * 1000 // 5 分钟缓存
|
||||
|
||||
// 过滤和排序后的上游模型列表
|
||||
const filteredUpstreamModels = computed(() => {
|
||||
const searchText = upstreamModelSearch.value.toLowerCase().trim()
|
||||
let result = [...upstreamModels.value]
|
||||
|
||||
// 按名称排序
|
||||
result.sort((a, b) => a.id.localeCompare(b.id))
|
||||
|
||||
// 搜索过滤(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchText) {
|
||||
const keywords = searchText.split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.id} ${m.owned_by || ''} ${m.api_format || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// 按 API 格式分组的上游模型列表
|
||||
interface UpstreamModelGroup {
|
||||
api_format: string
|
||||
models: Array<{ id: string; owned_by?: string; api_format?: string }>
|
||||
}
|
||||
|
||||
// 可添加的上游模型(排除已添加的)按分组显示
|
||||
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
||||
// 获取已添加的映射名称集合
|
||||
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
||||
|
||||
// 过滤掉已添加的模型
|
||||
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
||||
|
||||
// 按 API 格式分组
|
||||
const groups = new Map<string, UpstreamModelGroup>()
|
||||
|
||||
for (const model of availableModels) {
|
||||
const format = model.api_format || 'UNKNOWN'
|
||||
if (!groups.has(format)) {
|
||||
groups.set(format, { api_format: format, models: [] })
|
||||
}
|
||||
groups.get(format)!.models.push(model)
|
||||
}
|
||||
|
||||
// 按 API_FORMAT_LABELS 的键顺序排序
|
||||
const order = Object.keys(API_FORMAT_LABELS)
|
||||
return Array.from(groups.values()).sort((a, b) => {
|
||||
const aIndex = order.indexOf(a.api_format)
|
||||
const bIndex = order.indexOf(b.api_format)
|
||||
// 未知格式排最后
|
||||
if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
|
||||
if (aIndex === -1) return 1
|
||||
if (bIndex === -1) return -1
|
||||
return aIndex - bIndex
|
||||
})
|
||||
})
|
||||
|
||||
// 表单数据
|
||||
const formData = ref<{
|
||||
modelId: string
|
||||
apiFormats: string[]
|
||||
aliases: FormAlias[]
|
||||
}>({
|
||||
modelId: '',
|
||||
apiFormats: [],
|
||||
aliases: []
|
||||
})
|
||||
|
||||
// 检查是否有有效的别名
|
||||
const hasValidAliases = computed(() => {
|
||||
return formData.value.aliases.some(a => a.name.trim())
|
||||
})
|
||||
|
||||
// 获取 Provider 支持的 API 格式(按 API_FORMATS 定义的顺序排序)
|
||||
// 获取 Provider 支持的 API 格式
|
||||
const providerApiFormats = computed(() => {
|
||||
const formats = props.provider?.api_formats
|
||||
if (Array.isArray(formats) && formats.length > 0) {
|
||||
// 按 API_FORMAT_LABELS 中的键顺序排序
|
||||
const order = Object.keys(API_FORMAT_LABELS)
|
||||
return [...formats].sort((a, b) => order.indexOf(a) - order.indexOf(b))
|
||||
}
|
||||
return []
|
||||
})
|
||||
|
||||
// 分组数据结构
|
||||
interface AliasGroup {
|
||||
model: Model
|
||||
apiFormatsKey: string // 作用域的唯一标识(排序后的格式数组 JSON)
|
||||
apiFormats: string[] // 作用域
|
||||
aliases: ProviderModelAlias[] // 该组的所有映射
|
||||
}
|
||||
|
||||
// 生成作用域唯一键
|
||||
function getApiFormatsKey(formats: string[] | undefined): string {
|
||||
if (!formats || formats.length === 0) return ''
|
||||
@@ -669,9 +222,9 @@ const aliasGroups = computed<AliasGroup[]>(() => {
|
||||
const groupMap = new Map<string, AliasGroup>()
|
||||
|
||||
for (const model of models.value) {
|
||||
if (!model.provider_model_aliases || !Array.isArray(model.provider_model_aliases)) continue
|
||||
if (!model.provider_model_mappings || !Array.isArray(model.provider_model_mappings)) continue
|
||||
|
||||
for (const alias of model.provider_model_aliases) {
|
||||
for (const alias of model.provider_model_mappings) {
|
||||
const apiFormatsKey = getApiFormatsKey(alias.api_formats)
|
||||
const groupKey = `${model.id}|${apiFormatsKey}`
|
||||
|
||||
@@ -689,12 +242,10 @@ const aliasGroups = computed<AliasGroup[]>(() => {
|
||||
}
|
||||
}
|
||||
|
||||
// 对每个组内的别名按优先级排序
|
||||
for (const group of groups) {
|
||||
group.aliases.sort((a, b) => a.priority - b.priority)
|
||||
}
|
||||
|
||||
// 按模型名排序,同模型内按作用域排序
|
||||
return groups.sort((a, b) => {
|
||||
const nameA = (a.model.global_model_display_name || a.model.provider_model_name || '').toLowerCase()
|
||||
const nameB = (b.model.global_model_display_name || b.model.provider_model_name || '').toLowerCase()
|
||||
@@ -703,9 +254,6 @@ const aliasGroups = computed<AliasGroup[]>(() => {
|
||||
})
|
||||
})
|
||||
|
||||
// 当前编辑的分组
|
||||
const editingGroup = ref<AliasGroup | null>(null)
|
||||
|
||||
// 加载模型
|
||||
async function loadModels() {
|
||||
try {
|
||||
@@ -728,25 +276,6 @@ const deleteConfirmDescription = computed(() => {
|
||||
return `确定要删除模型「${modelName}」在作用域「${scopeText}」下的 ${aliases.length} 个映射吗?\n\n映射名称:${aliasNames}`
|
||||
})
|
||||
|
||||
// 切换 API 格式
|
||||
function toggleApiFormat(format: string) {
|
||||
const index = formData.value.apiFormats.indexOf(format)
|
||||
if (index >= 0) {
|
||||
formData.value.apiFormats.splice(index, 1)
|
||||
} else {
|
||||
formData.value.apiFormats.push(format)
|
||||
}
|
||||
}
|
||||
|
||||
// 切换分组折叠状态(上游模型列表)
|
||||
function toggleGroupCollapse(apiFormat: string) {
|
||||
if (collapsedGroups.value.has(apiFormat)) {
|
||||
collapsedGroups.value.delete(apiFormat)
|
||||
} else {
|
||||
collapsedGroups.value.add(apiFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// 切换映射组展开状态
|
||||
function toggleAliasGroupExpand(groupKey: string) {
|
||||
if (expandedAliasGroups.value.has(groupKey)) {
|
||||
@@ -756,147 +285,15 @@ function toggleAliasGroupExpand(groupKey: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 添加别名项
|
||||
function addAliasItem() {
|
||||
const maxPriority = formData.value.aliases.length > 0
|
||||
? Math.max(...formData.value.aliases.map(a => a.priority))
|
||||
: 0
|
||||
formData.value.aliases.push({ name: '', priority: maxPriority + 1 })
|
||||
}
|
||||
|
||||
// 删除别名项
|
||||
function removeAliasItem(index: number) {
|
||||
formData.value.aliases.splice(index, 1)
|
||||
}
|
||||
|
||||
// ===== 拖拽排序 =====
|
||||
function handleDragStart(index: number, event: DragEvent) {
|
||||
draggedIndex.value = index
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.effectAllowed = 'move'
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragEnd() {
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
}
|
||||
|
||||
function handleDragOver(index: number) {
|
||||
if (draggedIndex.value !== null && draggedIndex.value !== index) {
|
||||
dragOverIndex.value = index
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragLeave() {
|
||||
dragOverIndex.value = null
|
||||
}
|
||||
|
||||
function handleDrop(targetIndex: number) {
|
||||
const dragIndex = draggedIndex.value
|
||||
if (dragIndex === null || dragIndex === targetIndex) {
|
||||
dragOverIndex.value = null
|
||||
return
|
||||
}
|
||||
|
||||
const items = [...formData.value.aliases]
|
||||
const draggedItem = items[dragIndex]
|
||||
|
||||
// 记录每个别名的原始优先级(在修改前)
|
||||
const originalPriorityMap = new Map<number, number>()
|
||||
items.forEach((alias, idx) => {
|
||||
originalPriorityMap.set(idx, alias.priority)
|
||||
})
|
||||
|
||||
// 重排数组
|
||||
items.splice(dragIndex, 1)
|
||||
items.splice(targetIndex, 0, draggedItem)
|
||||
|
||||
// 按新顺序为每个组分配新的优先级
|
||||
// 同组的别名保持相同的优先级(被拖动的别名单独成组)
|
||||
const groupNewPriority = new Map<number, number>() // 原优先级 -> 新优先级
|
||||
let currentPriority = 1
|
||||
|
||||
items.forEach((alias) => {
|
||||
// 找到这个别名在原数组中的索引
|
||||
const originalIdx = formData.value.aliases.findIndex(a => a === alias)
|
||||
const originalPriority = originalIdx >= 0 ? originalPriorityMap.get(originalIdx)! : alias.priority
|
||||
|
||||
if (alias === draggedItem) {
|
||||
// 被拖动的别名是独立的新组,获得当前优先级
|
||||
alias.priority = currentPriority
|
||||
currentPriority++
|
||||
} else {
|
||||
if (groupNewPriority.has(originalPriority)) {
|
||||
// 这个组已经分配过优先级,使用相同的值
|
||||
alias.priority = groupNewPriority.get(originalPriority)!
|
||||
} else {
|
||||
// 这个组第一次出现,分配新优先级
|
||||
groupNewPriority.set(originalPriority, currentPriority)
|
||||
alias.priority = currentPriority
|
||||
currentPriority++
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
formData.value.aliases = items
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
}
|
||||
|
||||
// ===== 优先级编辑 =====
|
||||
function startEditPriority(index: number) {
|
||||
editingPriorityIndex.value = index
|
||||
}
|
||||
|
||||
function finishEditPriority(index: number, event: FocusEvent) {
|
||||
const input = event.target as HTMLInputElement
|
||||
const newPriority = parseInt(input.value) || 1
|
||||
formData.value.aliases[index].priority = Math.max(1, newPriority)
|
||||
editingPriorityIndex.value = null
|
||||
}
|
||||
|
||||
function cancelEditPriority() {
|
||||
editingPriorityIndex.value = null
|
||||
}
|
||||
|
||||
// 打开添加对话框
|
||||
function openAddDialog() {
|
||||
editingItem.value = null
|
||||
editingGroup.value = null
|
||||
formData.value = {
|
||||
modelId: '',
|
||||
apiFormats: [],
|
||||
aliases: []
|
||||
}
|
||||
// 重置状态
|
||||
editingPriorityIndex.value = null
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
// 重置上游模型状态
|
||||
upstreamModelsLoaded.value = false
|
||||
upstreamModels.value = []
|
||||
upstreamModelSearch.value = ''
|
||||
dialogOpen.value = true
|
||||
}
|
||||
|
||||
// 编辑分组
|
||||
function editGroup(group: AliasGroup) {
|
||||
editingGroup.value = group
|
||||
editingItem.value = { model: group.model, alias: group.aliases[0] } // 保持兼容
|
||||
formData.value = {
|
||||
modelId: group.model.id,
|
||||
apiFormats: [...group.apiFormats],
|
||||
aliases: group.aliases.map(a => ({ name: a.name, priority: a.priority }))
|
||||
}
|
||||
// 重置状态
|
||||
editingPriorityIndex.value = null
|
||||
draggedIndex.value = null
|
||||
dragOverIndex.value = null
|
||||
// 重置上游模型状态
|
||||
upstreamModelsLoaded.value = false
|
||||
upstreamModels.value = []
|
||||
upstreamModelSearch.value = ''
|
||||
dialogOpen.value = true
|
||||
}
|
||||
|
||||
@@ -913,17 +310,15 @@ async function confirmDelete() {
|
||||
const { model, aliases, apiFormatsKey } = deletingGroup.value
|
||||
|
||||
try {
|
||||
// 从模型的别名列表中移除该分组的所有别名
|
||||
const currentAliases = model.provider_model_aliases || []
|
||||
const currentAliases = model.provider_model_mappings || []
|
||||
const aliasNamesToRemove = new Set(aliases.map(a => a.name))
|
||||
const newAliases = currentAliases.filter((a: ProviderModelAlias) => {
|
||||
// 只移除同一作用域的别名
|
||||
const currentKey = getApiFormatsKey(a.api_formats)
|
||||
return !(currentKey === apiFormatsKey && aliasNamesToRemove.has(a.name))
|
||||
})
|
||||
|
||||
await updateModel(props.provider.id, model.id, {
|
||||
provider_model_aliases: newAliases.length > 0 ? newAliases : null
|
||||
provider_model_mappings: newAliases.length > 0 ? newAliases : null
|
||||
})
|
||||
|
||||
showSuccess('映射组已删除')
|
||||
@@ -936,89 +331,10 @@ async function confirmDelete() {
|
||||
}
|
||||
}
|
||||
|
||||
// 提交表单
|
||||
async function handleSubmit() {
|
||||
if (submitting.value) return
|
||||
if (!formData.value.modelId || formData.value.aliases.length === 0) return
|
||||
|
||||
// 过滤有效的别名
|
||||
const validAliases = formData.value.aliases.filter(a => a.name.trim())
|
||||
if (validAliases.length === 0) {
|
||||
showError('请至少添加一个有效的映射名称', '错误')
|
||||
return
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
const targetModel = models.value.find(m => m.id === formData.value.modelId)
|
||||
if (!targetModel) {
|
||||
showError('模型不存在', '错误')
|
||||
return
|
||||
}
|
||||
|
||||
const currentAliases = targetModel.provider_model_aliases || []
|
||||
let newAliases: ProviderModelAlias[]
|
||||
|
||||
// 构建新的别名对象(带作用域)
|
||||
const buildAlias = (a: FormAlias): ProviderModelAlias => ({
|
||||
name: a.name.trim(),
|
||||
priority: a.priority,
|
||||
...(formData.value.apiFormats.length > 0 ? { api_formats: formData.value.apiFormats } : {})
|
||||
})
|
||||
|
||||
if (editingGroup.value) {
|
||||
// 编辑分组模式:替换该分组的所有别名
|
||||
const oldApiFormatsKey = editingGroup.value.apiFormatsKey
|
||||
const oldAliasNames = new Set(editingGroup.value.aliases.map(a => a.name))
|
||||
|
||||
// 移除旧分组的所有别名
|
||||
const filteredAliases = currentAliases.filter((a: ProviderModelAlias) => {
|
||||
const currentKey = getApiFormatsKey(a.api_formats)
|
||||
return !(currentKey === oldApiFormatsKey && oldAliasNames.has(a.name))
|
||||
})
|
||||
|
||||
// 检查新别名是否与其他分组的别名重复
|
||||
const existingNames = new Set(filteredAliases.map((a: ProviderModelAlias) => a.name))
|
||||
const duplicates = validAliases.filter(a => existingNames.has(a.name.trim()))
|
||||
if (duplicates.length > 0) {
|
||||
showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误')
|
||||
return
|
||||
}
|
||||
|
||||
// 添加新的别名
|
||||
newAliases = [
|
||||
...filteredAliases,
|
||||
...validAliases.map(buildAlias)
|
||||
]
|
||||
} else {
|
||||
// 添加模式:检查是否重复并批量添加
|
||||
const existingNames = new Set(currentAliases.map((a: ProviderModelAlias) => a.name))
|
||||
const duplicates = validAliases.filter(a => existingNames.has(a.name.trim()))
|
||||
if (duplicates.length > 0) {
|
||||
showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误')
|
||||
return
|
||||
}
|
||||
newAliases = [
|
||||
...currentAliases,
|
||||
...validAliases.map(buildAlias)
|
||||
]
|
||||
}
|
||||
|
||||
await updateModel(props.provider.id, targetModel.id, {
|
||||
provider_model_aliases: newAliases
|
||||
})
|
||||
|
||||
showSuccess(editingGroup.value ? '映射组已更新' : '映射已添加')
|
||||
dialogOpen.value = false
|
||||
editingGroup.value = null
|
||||
editingItem.value = null
|
||||
await loadModels()
|
||||
emit('refresh')
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || '操作失败', '错误')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
// 对话框保存后回调
|
||||
async function onDialogSaved() {
|
||||
await loadModels()
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
// 监听 provider 变化
|
||||
@@ -1033,103 +349,4 @@ onMounted(() => {
|
||||
loadModels()
|
||||
}
|
||||
})
|
||||
|
||||
// ===== 快速添加(上游模型)=====
|
||||
async function fetchUpstreamModels() {
|
||||
if (!props.provider?.id) return
|
||||
|
||||
const providerId = props.provider.id
|
||||
upstreamModelSearch.value = ''
|
||||
|
||||
// 检查缓存
|
||||
const cached = upstreamModelsCache.value.get(providerId)
|
||||
if (cached && Date.now() - cached.timestamp < CACHE_TTL) {
|
||||
upstreamModels.value = cached.models
|
||||
upstreamModelsLoaded.value = true
|
||||
return
|
||||
}
|
||||
|
||||
fetchingUpstreamModels.value = true
|
||||
upstreamModels.value = []
|
||||
|
||||
try {
|
||||
const response = await adminApi.queryProviderModels(providerId)
|
||||
if (response.success && response.data?.models) {
|
||||
upstreamModels.value = response.data.models
|
||||
// 写入缓存
|
||||
upstreamModelsCache.value.set(providerId, {
|
||||
models: response.data.models,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
upstreamModelsLoaded.value = true
|
||||
} else {
|
||||
showError(response.data?.error || '获取模型列表失败', '错误')
|
||||
}
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || '获取模型列表失败', '错误')
|
||||
} finally {
|
||||
fetchingUpstreamModels.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 添加单个上游模型
|
||||
function addUpstreamModel(modelId: string) {
|
||||
// 检查是否已存在
|
||||
if (formData.value.aliases.some(a => a.name === modelId)) {
|
||||
return
|
||||
}
|
||||
|
||||
const maxPriority = formData.value.aliases.length > 0
|
||||
? Math.max(...formData.value.aliases.map(a => a.priority))
|
||||
: 0
|
||||
|
||||
formData.value.aliases.push({ name: modelId, priority: maxPriority + 1 })
|
||||
}
|
||||
|
||||
// 添加某个分组的所有模型
|
||||
function addAllFromGroup(apiFormat: string) {
|
||||
const group = groupedAvailableUpstreamModels.value.find(g => g.api_format === apiFormat)
|
||||
if (!group) return
|
||||
|
||||
let maxPriority = formData.value.aliases.length > 0
|
||||
? Math.max(...formData.value.aliases.map(a => a.priority))
|
||||
: 0
|
||||
|
||||
for (const model of group.models) {
|
||||
// 检查是否已存在
|
||||
if (!formData.value.aliases.some(a => a.name === model.id)) {
|
||||
maxPriority++
|
||||
formData.value.aliases.push({ name: model.id, priority: maxPriority })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新上游模型列表(清除缓存并重新获取)
|
||||
async function refreshUpstreamModels() {
|
||||
if (!props.provider?.id || refreshingUpstreamModels.value) return
|
||||
|
||||
const providerId = props.provider.id
|
||||
refreshingUpstreamModels.value = true
|
||||
|
||||
// 清除缓存
|
||||
upstreamModelsCache.value.delete(providerId)
|
||||
|
||||
try {
|
||||
const response = await adminApi.queryProviderModels(providerId)
|
||||
if (response.success && response.data?.models) {
|
||||
upstreamModels.value = response.data.models
|
||||
// 写入缓存
|
||||
upstreamModelsCache.value.set(providerId, {
|
||||
models: response.data.models,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
} else {
|
||||
showError(response.data?.error || '刷新失败', '错误')
|
||||
}
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || '刷新失败', '错误')
|
||||
} finally {
|
||||
refreshingUpstreamModels.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* 上游模型缓存 - 共享缓存,避免重复请求
|
||||
*/
|
||||
import { ref } from 'vue'
|
||||
import { adminApi } from '@/api/admin'
|
||||
import type { UpstreamModel } from '@/api/endpoints/types'
|
||||
|
||||
// 重新导出类型,便于调用方直接从本模块导入
|
||||
export type { UpstreamModel }
|
||||
|
||||
interface CacheEntry {
|
||||
models: UpstreamModel[]
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
type FetchResult = { models: UpstreamModel[]; error?: string }
|
||||
|
||||
// 全局缓存(模块级别,所有组件共享)
|
||||
const cache = new Map<string, CacheEntry>()
|
||||
const CACHE_TTL = 5 * 60 * 1000 // 5分钟
|
||||
|
||||
// 进行中的请求(用于去重并发请求)
|
||||
const pendingRequests = new Map<string, Promise<FetchResult>>()
|
||||
|
||||
// 请求状态
|
||||
const loadingMap = ref<Map<string, boolean>>(new Map())
|
||||
|
||||
export function useUpstreamModelsCache() {
|
||||
/**
|
||||
* 获取上游模型列表
|
||||
* @param providerId 提供商ID
|
||||
* @param forceRefresh 是否强制刷新
|
||||
* @returns 模型列表或 null(如果请求失败)
|
||||
*/
|
||||
async function fetchModels(
|
||||
providerId: string,
|
||||
forceRefresh = false
|
||||
): Promise<FetchResult> {
|
||||
// 检查缓存
|
||||
if (!forceRefresh) {
|
||||
const cached = cache.get(providerId)
|
||||
if (cached && Date.now() - cached.timestamp < CACHE_TTL) {
|
||||
return { models: cached.models }
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有进行中的请求(非强制刷新时复用)
|
||||
if (!forceRefresh && pendingRequests.has(providerId)) {
|
||||
return pendingRequests.get(providerId)!
|
||||
}
|
||||
|
||||
// 创建新请求
|
||||
const requestPromise = (async (): Promise<FetchResult> => {
|
||||
try {
|
||||
loadingMap.value.set(providerId, true)
|
||||
const response = await adminApi.queryProviderModels(providerId)
|
||||
|
||||
if (response.success && response.data?.models) {
|
||||
// 存入缓存
|
||||
cache.set(providerId, {
|
||||
models: response.data.models,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
return { models: response.data.models }
|
||||
} else {
|
||||
return { models: [], error: response.data?.error || '获取上游模型失败' }
|
||||
}
|
||||
} catch (err: any) {
|
||||
return { models: [], error: err.response?.data?.detail || '获取上游模型失败' }
|
||||
} finally {
|
||||
loadingMap.value.set(providerId, false)
|
||||
pendingRequests.delete(providerId)
|
||||
}
|
||||
})()
|
||||
|
||||
pendingRequests.set(providerId, requestPromise)
|
||||
return requestPromise
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取缓存的模型(不发起请求)
|
||||
*/
|
||||
function getCachedModels(providerId: string): UpstreamModel[] | null {
|
||||
const cached = cache.get(providerId)
|
||||
if (cached && Date.now() - cached.timestamp < CACHE_TTL) {
|
||||
return cached.models
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除指定提供商的缓存
|
||||
*/
|
||||
function clearCache(providerId: string) {
|
||||
cache.delete(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否正在加载
|
||||
*/
|
||||
function isLoading(providerId: string): boolean {
|
||||
return loadingMap.value.get(providerId) || false
|
||||
}
|
||||
|
||||
return {
|
||||
fetchModels,
|
||||
getCachedModels,
|
||||
clearCache,
|
||||
isLoading,
|
||||
loadingMap
|
||||
}
|
||||
}
|
||||
@@ -472,6 +472,7 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
import Separator from '@/components/ui/separator.vue'
|
||||
@@ -897,6 +898,16 @@ const providerHeadersWithDiff = computed(() => {
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// 添加 ESC 键监听
|
||||
useEscapeKey(() => {
|
||||
if (props.isOpen) {
|
||||
handleClose()
|
||||
}
|
||||
}, {
|
||||
disableOnInput: true,
|
||||
once: false
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -136,11 +136,20 @@
|
||||
<!-- 分隔线 -->
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 刷新按钮 -->
|
||||
<RefreshButton
|
||||
:loading="loading"
|
||||
@click="$emit('refresh')"
|
||||
/>
|
||||
<!-- 自动刷新按钮 -->
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
:class="autoRefresh ? 'text-primary' : ''"
|
||||
:title="autoRefresh ? '点击关闭自动刷新' : '点击开启自动刷新(每10秒刷新)'"
|
||||
@click="$emit('update:autoRefresh', !autoRefresh)"
|
||||
>
|
||||
<RefreshCcw
|
||||
class="w-3.5 h-3.5"
|
||||
:class="autoRefresh ? 'animate-spin' : ''"
|
||||
/>
|
||||
</Button>
|
||||
</template>
|
||||
|
||||
<Table>
|
||||
@@ -357,14 +366,34 @@
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="text-right py-4 w-[70px]">
|
||||
<!-- pending 状态:只显示增长的总时间 -->
|
||||
<div
|
||||
v-if="record.status === 'pending' || record.status === 'streaming'"
|
||||
v-if="record.status === 'pending'"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
<span class="text-muted-foreground">-</span>
|
||||
<span class="text-primary tabular-nums">
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
</div>
|
||||
<!-- streaming 状态:首字固定 + 总时间增长 -->
|
||||
<div
|
||||
v-else-if="record.status === 'streaming'"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
<span
|
||||
v-if="record.first_byte_time_ms != null"
|
||||
class="tabular-nums"
|
||||
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
|
||||
<span
|
||||
v-else
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
<span class="text-primary tabular-nums">
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
</div>
|
||||
<!-- 已完成状态:首字 + 总耗时 -->
|
||||
<div
|
||||
v-else-if="record.response_time_ms != null"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
@@ -408,6 +437,7 @@ import { ref, computed, onUnmounted, watch } from 'vue'
|
||||
import {
|
||||
TableCard,
|
||||
Badge,
|
||||
Button,
|
||||
Select,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
@@ -420,8 +450,8 @@ import {
|
||||
TableHead,
|
||||
TableCell,
|
||||
Pagination,
|
||||
RefreshButton,
|
||||
} from '@/components/ui'
|
||||
import { RefreshCcw } from 'lucide-vue-next'
|
||||
import { formatTokens, formatCurrency } from '@/utils/format'
|
||||
import { formatDateTime } from '../composables'
|
||||
import { useRowClick } from '@/composables/useRowClick'
|
||||
@@ -453,6 +483,8 @@ const props = defineProps<{
|
||||
pageSize: number
|
||||
totalRecords: number
|
||||
pageSizeOptions: number[]
|
||||
// 自动刷新
|
||||
autoRefresh: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -463,6 +495,7 @@ const emit = defineEmits<{
|
||||
'update:filterStatus': [value: string]
|
||||
'update:currentPage': [value: number]
|
||||
'update:pageSize': [value: number]
|
||||
'update:autoRefresh': [value: boolean]
|
||||
'refresh': []
|
||||
'showDetail': [id: string]
|
||||
}>()
|
||||
|
||||
@@ -403,7 +403,7 @@ function getUsageRecords() {
|
||||
return cachedUsageRecords
|
||||
}
|
||||
|
||||
// Mock 别名数据
|
||||
// Mock 映射数据
|
||||
const MOCK_ALIASES = [
|
||||
{ id: 'alias-001', source_model: 'claude-4-sonnet', target_global_model_id: 'gm-001', target_global_model_name: 'claude-sonnet-4-20250514', target_global_model_display_name: 'Claude Sonnet 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
||||
{ id: 'alias-002', source_model: 'claude-4-opus', target_global_model_id: 'gm-002', target_global_model_name: 'claude-opus-4-20250514', target_global_model_display_name: 'Claude Opus 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
||||
@@ -1682,7 +1682,7 @@ registerDynamicRoute('GET', '/api/admin/models/mappings/:mappingId', async (_con
|
||||
requireAdmin()
|
||||
const alias = MOCK_ALIASES.find(a => a.id === params.mappingId)
|
||||
if (!alias) {
|
||||
throw { response: createMockResponse({ detail: '别名不存在' }, 404) }
|
||||
throw { response: createMockResponse({ detail: '映射不存在' }, 404) }
|
||||
}
|
||||
return createMockResponse(alias)
|
||||
})
|
||||
@@ -1693,7 +1693,7 @@ registerDynamicRoute('PATCH', '/api/admin/models/mappings/:mappingId', async (co
|
||||
requireAdmin()
|
||||
const alias = MOCK_ALIASES.find(a => a.id === params.mappingId)
|
||||
if (!alias) {
|
||||
throw { response: createMockResponse({ detail: '别名不存在' }, 404) }
|
||||
throw { response: createMockResponse({ detail: '映射不存在' }, 404) }
|
||||
}
|
||||
const body = JSON.parse(config.data || '{}')
|
||||
return createMockResponse({ ...alias, ...body, updated_at: new Date().toISOString() })
|
||||
@@ -1705,7 +1705,7 @@ registerDynamicRoute('DELETE', '/api/admin/models/mappings/:mappingId', async (_
|
||||
requireAdmin()
|
||||
const alias = MOCK_ALIASES.find(a => a.id === params.mappingId)
|
||||
if (!alias) {
|
||||
throw { response: createMockResponse({ detail: '别名不存在' }, 404) }
|
||||
throw { response: createMockResponse({ detail: '映射不存在' }, 404) }
|
||||
}
|
||||
return createMockResponse({ message: '删除成功(演示模式)' })
|
||||
})
|
||||
|
||||
@@ -142,32 +142,37 @@ async function resetAffinitySearch() {
|
||||
await fetchAffinityList()
|
||||
}
|
||||
|
||||
async function clearUserCache(identifier: string, displayName?: string) {
|
||||
const target = identifier?.trim()
|
||||
if (!target) {
|
||||
showError('无法识别标识符')
|
||||
async function clearSingleAffinity(item: UserAffinity) {
|
||||
const affinityKey = item.affinity_key?.trim()
|
||||
const endpointId = item.endpoint_id?.trim()
|
||||
const modelId = item.global_model_id?.trim()
|
||||
const apiFormat = item.api_format?.trim()
|
||||
|
||||
if (!affinityKey || !endpointId || !modelId || !apiFormat) {
|
||||
showError('缓存记录信息不完整,无法删除')
|
||||
return
|
||||
}
|
||||
|
||||
const label = displayName || target
|
||||
const label = item.user_api_key_name || affinityKey
|
||||
const modelLabel = item.model_display_name || item.model_name || modelId
|
||||
const confirmed = await showConfirm({
|
||||
title: '确认清除',
|
||||
message: `确定要清除 ${label} 的缓存吗?`,
|
||||
message: `确定要清除 ${label} 在模型 ${modelLabel} 上的缓存亲和性吗?`,
|
||||
confirmText: '确认清除',
|
||||
variant: 'destructive'
|
||||
})
|
||||
|
||||
if (!confirmed) return
|
||||
|
||||
clearingRowAffinityKey.value = target
|
||||
clearingRowAffinityKey.value = affinityKey
|
||||
try {
|
||||
await cacheApi.clearUserCache(target)
|
||||
await cacheApi.clearSingleAffinity(affinityKey, endpointId, modelId, apiFormat)
|
||||
showSuccess('清除成功')
|
||||
await fetchCacheStats()
|
||||
await fetchAffinityList(tableKeyword.value.trim() || undefined)
|
||||
} catch (error) {
|
||||
showError('清除失败')
|
||||
log.error('清除用户缓存失败', error)
|
||||
log.error('清除单条缓存失败', error)
|
||||
} finally {
|
||||
clearingRowAffinityKey.value = null
|
||||
}
|
||||
@@ -618,7 +623,7 @@ onBeforeUnmount(() => {
|
||||
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
|
||||
:disabled="clearingRowAffinityKey === item.affinity_key"
|
||||
title="清除缓存"
|
||||
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
|
||||
@click="clearSingleAffinity(item)"
|
||||
>
|
||||
<Trash2 class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
@@ -668,7 +673,7 @@ onBeforeUnmount(() => {
|
||||
variant="ghost"
|
||||
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive shrink-0"
|
||||
:disabled="clearingRowAffinityKey === item.affinity_key"
|
||||
@click="clearUserCache(item.affinity_key, item.user_api_key_name || item.affinity_key)"
|
||||
@click="clearSingleAffinity(item)"
|
||||
>
|
||||
<Trash2 class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
|
||||
@@ -464,6 +464,7 @@
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- 导入配置对话框 -->
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
:page-size="pageSize"
|
||||
:total-records="totalRecords"
|
||||
:page-size-options="pageSizeOptions"
|
||||
:auto-refresh="globalAutoRefresh"
|
||||
@update:selected-period="handlePeriodChange"
|
||||
@update:filter-user="handleFilterUserChange"
|
||||
@update:filter-model="handleFilterModelChange"
|
||||
@@ -72,6 +73,7 @@
|
||||
@update:filter-status="handleFilterStatusChange"
|
||||
@update:current-page="handlePageChange"
|
||||
@update:page-size="handlePageSizeChange"
|
||||
@update:auto-refresh="handleAutoRefreshChange"
|
||||
@refresh="refreshData"
|
||||
@export="exportData"
|
||||
@show-detail="showRequestDetail"
|
||||
@@ -214,7 +216,10 @@ const hasActiveRequests = computed(() => activeRequestIds.value.length > 0)
|
||||
|
||||
// 自动刷新定时器
|
||||
let autoRefreshTimer: ReturnType<typeof setInterval> | null = null
|
||||
const AUTO_REFRESH_INTERVAL = 1000 // 1秒刷新一次
|
||||
let globalAutoRefreshTimer: ReturnType<typeof setInterval> | null = null
|
||||
const AUTO_REFRESH_INTERVAL = 1000 // 1秒刷新一次(用于活跃请求)
|
||||
const GLOBAL_AUTO_REFRESH_INTERVAL = 10000 // 10秒刷新一次(全局自动刷新)
|
||||
const globalAutoRefresh = ref(false) // 全局自动刷新开关
|
||||
|
||||
// 轮询活跃请求状态(轻量级,只更新状态变化的记录)
|
||||
async function pollActiveRequests() {
|
||||
@@ -278,9 +283,34 @@ watch(hasActiveRequests, (hasActive) => {
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
// 启动全局自动刷新
|
||||
function startGlobalAutoRefresh() {
|
||||
if (globalAutoRefreshTimer) return
|
||||
globalAutoRefreshTimer = setInterval(refreshData, GLOBAL_AUTO_REFRESH_INTERVAL)
|
||||
}
|
||||
|
||||
// 停止全局自动刷新
|
||||
function stopGlobalAutoRefresh() {
|
||||
if (globalAutoRefreshTimer) {
|
||||
clearInterval(globalAutoRefreshTimer)
|
||||
globalAutoRefreshTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
// 处理自动刷新开关变化
|
||||
function handleAutoRefreshChange(value: boolean) {
|
||||
globalAutoRefresh.value = value
|
||||
if (value) {
|
||||
startGlobalAutoRefresh()
|
||||
} else {
|
||||
stopGlobalAutoRefresh()
|
||||
}
|
||||
}
|
||||
|
||||
// 组件卸载时清理定时器
|
||||
onUnmounted(() => {
|
||||
stopAutoRefresh()
|
||||
stopGlobalAutoRefresh()
|
||||
})
|
||||
|
||||
// 用户页面的前端分页
|
||||
|
||||
@@ -350,6 +350,7 @@ import {
|
||||
Layers,
|
||||
Image as ImageIcon
|
||||
} from 'lucide-vue-next'
|
||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
@@ -453,6 +454,16 @@ function getFirst1hCachePrice(tieredPricing: TieredPricingConfig | undefined | n
|
||||
if (!tieredPricing?.tiers?.length) return '-'
|
||||
return get1hCachePrice(tieredPricing.tiers[0])
|
||||
}
|
||||
|
||||
// 添加 ESC 键监听
|
||||
useEscapeKey(() => {
|
||||
if (props.open) {
|
||||
handleClose()
|
||||
}
|
||||
}, {
|
||||
disableOnInput: true,
|
||||
once: false
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -246,6 +246,15 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
if "api_key" in update_data:
|
||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||
|
||||
# 特殊处理 max_concurrent:需要区分"未提供"和"显式设置为 null"
|
||||
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
|
||||
if "max_concurrent" in self.key_data.model_fields_set:
|
||||
update_data["max_concurrent"] = self.key_data.max_concurrent
|
||||
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
|
||||
if self.key_data.max_concurrent is None:
|
||||
update_data["learned_max_concurrent"] = None
|
||||
logger.info("Key %s 切换为自适应并发模式", self.key_id)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(key, field, value)
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
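The max_concurrent handling above relies on Pydantic v2's model_fields_set to tell a field that was omitted from the request body apart from one that was explicitly sent as null. A minimal, self-contained sketch of that distinction — the KeyUpdate model below is a hypothetical stand-in, not the project's actual schema:

from typing import Optional

from pydantic import BaseModel


class KeyUpdate(BaseModel):
    # Hypothetical stand-in for the real endpoint-key update schema.
    max_concurrent: Optional[int] = None


omitted = KeyUpdate.model_validate({})                               # field not sent at all
explicit_null = KeyUpdate.model_validate({"max_concurrent": None})   # field sent as null

print("max_concurrent" in omitted.model_fields_set)        # False -> leave the column untouched
print("max_concurrent" in explicit_null.model_fields_set)  # True  -> write None and reset the learned limit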
@@ -253,7 +262,7 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
|
||||
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
|
||||
@@ -186,6 +186,30 @@ async def clear_user_cache(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/affinity/{affinity_key}/{endpoint_id}/{model_id}/{api_format}")
|
||||
async def clear_single_affinity(
|
||||
affinity_key: str,
|
||||
endpoint_id: str,
|
||||
model_id: str,
|
||||
api_format: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Clear a single cache affinity entry
|
||||
|
||||
Parameters:
|
||||
- affinity_key: API Key ID
|
||||
- endpoint_id: Endpoint ID
|
||||
- model_id: Model ID (GlobalModel ID)
|
||||
- api_format: API format (claude/openai)
|
||||
"""
|
||||
adapter = AdminClearSingleAffinityAdapter(
|
||||
affinity_key=affinity_key, endpoint_id=endpoint_id, model_id=model_id, api_format=api_format
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
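A hedged usage sketch for this route. The /api/admin/cache prefix, host, IDs, and token below are placeholders — the route itself only declares /affinity/{affinity_key}/{endpoint_id}/{model_id}/{api_format}, and the admin UI reaches it through cacheApi.clearSingleAffinity:

import asyncio

import httpx


async def clear_one_affinity() -> None:
    # Placeholder prefix and IDs; adjust to wherever this router is actually mounted.
    path = "/api/admin/cache/affinity/ak-123/ep-456/gm-789/claude"
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.delete(path, headers={"Authorization": "Bearer <admin-token>"})
        print(resp.status_code, resp.json())


asyncio.run(clear_one_affinity())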
|
||||
|
||||
@router.delete("")
|
||||
async def clear_all_cache(
|
||||
request: Request,
|
||||
@@ -655,6 +679,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
||||
"key_name": key.name if key else None,
|
||||
"key_prefix": provider_key_masked,
|
||||
"rate_multiplier": key.rate_multiplier if key else 1.0,
|
||||
"global_model_id": affinity.get("model_name"), # 原始的 global_model_id
|
||||
"model_name": (
|
||||
global_model_map.get(affinity.get("model_name")).name
|
||||
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
|
||||
@@ -817,6 +842,65 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminClearSingleAffinityAdapter(AdminApiAdapter):
|
||||
affinity_key: str
|
||||
endpoint_id: str
|
||||
model_id: str
|
||||
api_format: str
|
||||
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
|
||||
# 直接获取指定的亲和性记录(无需遍历全部)
|
||||
existing_affinity = await affinity_mgr.get_affinity(
|
||||
self.affinity_key, self.api_format, self.model_id
|
||||
)
|
||||
|
||||
if not existing_affinity:
|
||||
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
|
||||
|
||||
# 验证 endpoint_id 是否匹配
|
||||
if existing_affinity.endpoint_id != self.endpoint_id:
|
||||
raise HTTPException(status_code=404, detail="未找到指定的缓存亲和性记录")
|
||||
|
||||
# 失效单条记录
|
||||
await affinity_mgr.invalidate_affinity(
|
||||
self.affinity_key, self.api_format, self.model_id, endpoint_id=self.endpoint_id
|
||||
)
|
||||
|
||||
# 获取用于日志的信息
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.affinity_key).first()
|
||||
api_key_name = api_key.name if api_key else None
|
||||
|
||||
logger.info(
|
||||
f"已清除单条缓存亲和性: affinity_key={self.affinity_key[:8]}..., "
|
||||
f"endpoint_id={self.endpoint_id[:8]}..., model_id={self.model_id[:8]}..."
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="cache_clear_single",
|
||||
affinity_key=self.affinity_key,
|
||||
endpoint_id=self.endpoint_id,
|
||||
model_id=self.model_id,
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"已清除缓存亲和性: {api_key_name or self.affinity_key[:8]}",
|
||||
"affinity_key": self.affinity_key,
|
||||
"endpoint_id": self.endpoint_id,
|
||||
"model_id": self.model_id,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除单条缓存亲和性失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
class AdminClearAllCacheAdapter(AdminApiAdapter):
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
try:
|
||||
@@ -863,7 +947,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
||||
class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
||||
from src.config.constants import ConcurrencyDefaults
|
||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||
|
||||
# 获取动态预留管理器的配置
|
||||
@@ -874,7 +958,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
"status": "ok",
|
||||
"data": {
|
||||
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
||||
"cache_reservation_ratio": ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
|
||||
"dynamic_reservation": {
|
||||
"enabled": True,
|
||||
"config": reservation_stats["config"],
|
||||
@@ -897,7 +981,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
context.add_audit_metadata(
|
||||
action="cache_config",
|
||||
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
||||
cache_reservation_ratio=ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
|
||||
dynamic_reservation_enabled=True,
|
||||
)
|
||||
return response
|
||||
@@ -1083,14 +1167,14 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
continue
|
||||
# 检查是否在别名列表中
|
||||
if model.provider_model_aliases:
|
||||
alias_names = [
|
||||
# 检查是否在映射列表中
|
||||
if model.provider_model_mappings:
|
||||
mapping_list = [
|
||||
a.get("name")
|
||||
for a in model.provider_model_aliases
|
||||
for a in model.provider_model_mappings
|
||||
if isinstance(a, dict)
|
||||
]
|
||||
if mapping_name in alias_names:
|
||||
if mapping_name in mapping_list:
|
||||
provider_names.append(
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
@@ -1152,19 +1236,19 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
try:
|
||||
cached_data = json.loads(cached_str)
|
||||
provider_model_name = cached_data.get("provider_model_name")
|
||||
provider_model_aliases = cached_data.get("provider_model_aliases", [])
|
||||
cached_model_mappings = cached_data.get("provider_model_mappings", [])
|
||||
|
||||
# 获取 Provider 和 GlobalModel 信息
|
||||
provider = provider_map.get(provider_id)
|
||||
global_model = global_model_map.get(global_model_id)
|
||||
|
||||
if provider and global_model:
|
||||
# 提取别名名称
|
||||
alias_names = []
|
||||
if provider_model_aliases:
|
||||
for alias_entry in provider_model_aliases:
|
||||
if isinstance(alias_entry, dict) and alias_entry.get("name"):
|
||||
alias_names.append(alias_entry["name"])
|
||||
# 提取映射名称
|
||||
mapping_names = []
|
||||
if cached_model_mappings:
|
||||
for mapping_entry in cached_model_mappings:
|
||||
if isinstance(mapping_entry, dict) and mapping_entry.get("name"):
|
||||
mapping_names.append(mapping_entry["name"])
|
||||
|
||||
# provider_model_name 为空时跳过
|
||||
if not provider_model_name:
|
||||
@@ -1172,14 +1256,14 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
|
||||
# 只显示有实际映射的条目:
|
||||
# 1. 全局模型名 != Provider 模型名(模型名称映射)
|
||||
# 2. 或者有别名配置
|
||||
# 2. 或者有映射配置
|
||||
has_name_mapping = global_model.name != provider_model_name
|
||||
has_aliases = len(alias_names) > 0
|
||||
has_mappings = len(mapping_names) > 0
|
||||
|
||||
if has_name_mapping or has_aliases:
|
||||
# 构建用于展示的别名列表
|
||||
# 如果只有名称映射没有别名,则用 global_model_name 作为"请求名称"
|
||||
display_aliases = alias_names if alias_names else [global_model.name]
|
||||
if has_name_mapping or has_mappings:
|
||||
# 构建用于展示的映射列表
|
||||
# 如果只有名称映射没有额外映射,则用 global_model_name 作为"请求名称"
|
||||
display_mappings = mapping_names if mapping_names else [global_model.name]
|
||||
|
||||
provider_model_mappings.append({
|
||||
"provider_id": provider_id,
|
||||
@@ -1188,7 +1272,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
"global_model_name": global_model.name,
|
||||
"global_model_display_name": global_model.display_name,
|
||||
"provider_model_name": provider_model_name,
|
||||
"aliases": display_aliases,
|
||||
"aliases": display_mappings,
|
||||
"ttl": ttl if ttl > 0 else None,
|
||||
"hit_count": hit_count,
|
||||
})
|
||||
|
||||
@@ -11,6 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import get_adapter_class
|
||||
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
@@ -33,142 +35,19 @@ class ModelsQueryRequest(BaseModel):
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
async def _fetch_openai_models(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
api_format: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 OpenAI 格式的模型列表
|
||||
def _get_adapter_for_format(api_format: str):
|
||||
"""根据 API 格式获取对应的 Adapter 类"""
|
||||
# 先检查 Chat Adapter 注册表
|
||||
adapter_class = get_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class
|
||||
|
||||
Returns:
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖 Authorization
|
||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||
headers.update(safe_headers)
|
||||
# 再检查 CLI Adapter 注册表
|
||||
cli_adapter_class = get_cli_adapter_class(api_format)
|
||||
if cli_adapter_class:
|
||||
return cli_adapter_class
|
||||
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
# 记录详细的错误信息
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
async def _fetch_claude_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Claude 格式的模型列表
|
||||
|
||||
Returns:
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
async def _fetch_gemini_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Gemini 格式的模型列表
|
||||
|
||||
Returns:
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
# 兼容 base_url 已包含 /v1beta 的情况
|
||||
base_url_clean = base_url.rstrip("/")
|
||||
if base_url_clean.endswith("/v1beta"):
|
||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||
else:
|
||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url)
|
||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "models" in data:
|
||||
# 转换为统一格式
|
||||
return [
|
||||
{
|
||||
"id": m.get("name", "").replace("models/", ""),
|
||||
"owned_by": "google",
|
||||
"display_name": m.get("displayName", ""),
|
||||
"api_format": api_format,
|
||||
}
|
||||
for m in data["models"]
|
||||
], None
|
||||
return [], None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
return None
|
||||
|
||||
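A minimal sketch of how a caller can use this resolver together with the fetch_models signature shown further down (base_url and api_key are placeholders; error handling is reduced to returning an empty list):

import httpx


async def list_models_for(api_format: str, base_url: str, api_key: str) -> list:
    adapter_cls = _get_adapter_for_format(api_format)
    if adapter_cls is None:
        return []  # unknown format; the endpoint below reports "Unknown API format" instead
    async with httpx.AsyncClient(timeout=30.0) as client:
        models, error = await adapter_cls.fetch_models(client, base_url, api_key, None)
    if error:
        return []
    return models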
|
||||
@router.post("/models")
|
||||
@@ -180,10 +59,10 @@ async def query_available_models(
|
||||
"""
|
||||
查询提供商可用模型
|
||||
|
||||
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||
遍历所有活跃端点,根据端点的 API 格式选择正确的 Adapter 进行请求:
|
||||
- OPENAI/OPENAI_CLI: 使用 OpenAIChatAdapter.fetch_models
|
||||
- CLAUDE/CLAUDE_CLI: 使用 ClaudeChatAdapter.fetch_models
|
||||
- GEMINI/GEMINI_CLI: 使用 GeminiChatAdapter.fetch_models
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
@@ -265,37 +144,53 @@ async def query_available_models(
|
||||
base_url = base_url.rstrip("/")
|
||||
api_format = config["api_format"]
|
||||
api_key_value = config["api_key"]
|
||||
extra_headers = config["extra_headers"]
|
||||
extra_headers = config.get("extra_headers")
|
||||
|
||||
try:
|
||||
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||
else:
|
||||
return await _fetch_openai_models(
|
||||
client, base_url, api_key_value, api_format, extra_headers
|
||||
)
|
||||
# 获取对应的 Adapter 类并调用 fetch_models
|
||||
adapter_class = _get_adapter_for_format(api_format)
|
||||
if not adapter_class:
|
||||
return [], f"Unknown API format: {api_format}"
|
||||
models, error = await adapter_class.fetch_models(
|
||||
client, base_url, api_key_value, extra_headers
|
||||
)
|
||||
# 确保所有模型都有 api_format 字段
|
||||
for m in models:
|
||||
if "api_format" not in m:
|
||||
m["api_format"] = api_format
|
||||
return models, error
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||
return [], f"{api_format}: {str(e)}"
|
||||
|
||||
# 限制并发请求数量,避免触发上游速率限制
|
||||
MAX_CONCURRENT_REQUESTS = 5
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
async def fetch_with_semaphore(
|
||||
client: httpx.AsyncClient, config: dict
|
||||
) -> tuple[list, Optional[str]]:
|
||||
async with semaphore:
|
||||
return await fetch_endpoint_models(client, config)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
results = await asyncio.gather(
|
||||
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||
*[fetch_with_semaphore(client, c) for c in endpoint_configs]
|
||||
)
|
||||
for models, error in results:
|
||||
all_models.extend(models)
|
||||
if error:
|
||||
errors.append(error)
|
||||
|
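# Read together, the added lines above amount to the following shape — a sketch of the
# post-change flow, not a verbatim excerpt; fetch_endpoint_models is the inner helper
# defined in this endpoint.
import asyncio

import httpx

MAX_CONCURRENT_REQUESTS = 5
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

async def fetch_with_semaphore(client: httpx.AsyncClient, config: dict):
    # At most 5 upstream /models calls run at once, to avoid tripping upstream rate limits.
    async with semaphore:
        return await fetch_endpoint_models(client, config)

async def collect(endpoint_configs: list[dict]):
    all_models, errors = [], []
    async with httpx.AsyncClient(timeout=30.0) as client:
        results = await asyncio.gather(*[fetch_with_semaphore(client, c) for c in endpoint_configs])
    for models, error in results:
        all_models.extend(models)
        if error:
            errors.append(error)
    return all_models, errors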
||||
# 按 model id 去重(保留第一个)
|
||||
seen_ids: set[str] = set()
|
||||
# 按 model id + api_format 去重(保留第一个)
|
||||
seen_keys: set[str] = set()
|
||||
unique_models: list = []
|
||||
for model in all_models:
|
||||
model_id = model.get("id")
|
||||
if model_id and model_id not in seen_ids:
|
||||
seen_ids.add(model_id)
|
||||
api_format = model.get("api_format", "")
|
||||
unique_key = f"{model_id}:{api_format}"
|
||||
if model_id and unique_key not in seen_keys:
|
||||
seen_keys.add(unique_key)
|
||||
unique_models.append(model)
|
||||
|
||||
error = "; ".join(errors) if errors else None
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.models_service import invalidate_models_list_cache
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
@@ -21,16 +22,18 @@ from src.models.api import (
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignModelsToProviderRequest,
|
||||
BatchAssignModelsToProviderResponse,
|
||||
ImportFromUpstreamRequest,
|
||||
ImportFromUpstreamResponse,
|
||||
ImportFromUpstreamSuccessItem,
|
||||
ImportFromUpstreamErrorItem,
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(tags=["Model Management"])
|
||||
@@ -157,6 +160,28 @@ async def batch_assign_global_models_to_provider(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider_id}/import-from-upstream",
|
||||
response_model=ImportFromUpstreamResponse,
|
||||
)
|
||||
async def import_models_from_upstream(
|
||||
provider_id: str,
|
||||
payload: ImportFromUpstreamRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ImportFromUpstreamResponse:
|
||||
"""
|
||||
从上游提供商导入模型
|
||||
|
||||
流程:
|
||||
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
|
||||
2. 如不存在,自动创建新的 GlobalModel(使用默认配置)
|
||||
3. 创建 Model 关联到当前 Provider
|
||||
"""
|
||||
adapter = AdminImportFromUpstreamAdapter(provider_id=provider_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
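A hedged request/response sketch for this import route. Only model_ids, success, and errors are taken from the schemas above; the URL prefix, host, and token are assumptions:

import httpx


def import_models(provider_id: str, model_ids: list[str]) -> None:
    # Assumed admin prefix; the router itself only defines "/{provider_id}/import-from-upstream".
    url = f"http://localhost:8000/api/admin/providers/{provider_id}/import-from-upstream"
    resp = httpx.post(
        url,
        json={"model_ids": model_ids},
        headers={"Authorization": "Bearer <admin-token>"},
    )
    data = resp.json()
    # Per ImportFromUpstreamResponse: "success" items carry global_model_id / provider_model_id,
    # "errors" items carry model_id / error.
    print(len(data["success"]), "imported,", len(data["errors"]), "failed")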
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@@ -419,4 +444,135 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
||||
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
|
||||
)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
if success:
|
||||
await invalidate_models_list_cache()
|
||||
|
||||
return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
"""从上游提供商导入模型"""
|
||||
|
||||
provider_id: str
|
||||
payload: ImportFromUpstreamRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
success: list[ImportFromUpstreamSuccessItem] = []
|
||||
errors: list[ImportFromUpstreamErrorItem] = []
|
||||
|
||||
# 默认阶梯计费配置(免费)
|
||||
default_tiered_pricing = {
|
||||
"tiers": [
|
||||
{
|
||||
"up_to": None,
|
||||
"input_price_per_1m": 0.0,
|
||||
"output_price_per_1m": 0.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
for model_id in self.payload.model_ids:
|
||||
# 输入验证:检查 model_id 长度
|
||||
if not model_id or len(model_id) > 100:
|
||||
errors.append(
|
||||
ImportFromUpstreamErrorItem(
|
||||
model_id=model_id[:50] + "..." if model_id and len(model_id) > 50 else model_id or "<empty>",
|
||||
error="Invalid model_id: must be 1-100 characters",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用 savepoint 确保单个模型导入的原子性
|
||||
savepoint = db.begin_nested()
|
||||
try:
|
||||
# 1. 检查是否已存在同名的 GlobalModel
|
||||
global_model = (
|
||||
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
|
||||
)
|
||||
created_global_model = False
|
||||
|
||||
if not global_model:
|
||||
# 2. 创建新的 GlobalModel
|
||||
global_model = GlobalModel(
|
||||
name=model_id,
|
||||
display_name=model_id,
|
||||
default_tiered_pricing=default_tiered_pricing,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(global_model)
|
||||
db.flush()
|
||||
created_global_model = True
|
||||
logger.info(
|
||||
f"Created new GlobalModel: {model_id} during upstream import"
|
||||
)
|
||||
|
||||
# 3. 检查是否已存在关联
|
||||
existing = (
|
||||
db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == self.provider_id,
|
||||
Model.global_model_id == global_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
# 已存在关联,提交 savepoint 并记录成功
|
||||
savepoint.commit()
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
provider_model_id=existing.id,
|
||||
created_global_model=created_global_model,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# 4. 创建新的 Model 记录
|
||||
new_model = Model(
|
||||
provider_id=self.provider_id,
|
||||
global_model_id=global_model.id,
|
||||
provider_model_name=global_model.name,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(new_model)
|
||||
db.flush()
|
||||
|
||||
# 提交 savepoint
|
||||
savepoint.commit()
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
provider_model_id=new_model.id,
|
||||
created_global_model=created_global_model,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# 回滚到 savepoint
|
||||
savepoint.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing model {model_id}: {e}")
|
||||
errors.append(ImportFromUpstreamErrorItem(model_id=model_id, error=str(e)))
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
|
||||
)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
if success:
|
||||
await invalidate_models_list_cache()
|
||||
|
||||
return ImportFromUpstreamResponse(success=success, errors=errors)
|
||||
|
||||
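The savepoint-per-model pattern above lets one failing import roll back on its own without losing the rest of the batch. Below is a minimal self-contained sketch of the same idea; the table, names, and SQLite engine are placeholders for illustration, not Aether's actual models (pysqlite needs the two event listeners for SAVEPOINT to behave correctly):

from sqlalchemy import Column, Integer, String, create_engine, event
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):  # placeholder table, not Aether's GlobalModel/Model
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True)

engine = create_engine("sqlite://")

@event.listens_for(engine, "connect")
def _do_connect(dbapi_connection, connection_record):
    # disable pysqlite's implicit BEGIN so SAVEPOINT works as expected
    dbapi_connection.isolation_level = None

@event.listens_for(engine, "begin")
def _do_begin(conn):
    conn.exec_driver_sql("BEGIN")

Base.metadata.create_all(engine)

with Session(engine) as db:
    imported, errors = [], []
    for name in ["model-a", "model-a", "model-b"]:   # second item violates UNIQUE
        savepoint = db.begin_nested()                # SAVEPOINT per item
        try:
            db.add(Item(name=name))
            db.flush()                               # surface the IntegrityError here
            savepoint.commit()
            imported.append(name)
        except Exception as exc:
            savepoint.rollback()                     # only this item is undone
            errors.append((name, type(exc).__name__))
    db.commit()                                      # one outer commit for all successes
    print(imported, errors)  # ['model-a', 'model-b'] [('model-a', 'IntegrityError')]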
@@ -436,7 +436,7 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
||||
{
|
||||
"global_model_name": global_model.name if global_model else None,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": model.provider_model_aliases,
|
||||
"provider_model_mappings": model.provider_model_mappings,
|
||||
"price_per_request": model.price_per_request,
|
||||
"tiered_pricing": model.tiered_pricing,
|
||||
"supports_vision": model.supports_vision,
|
||||
@@ -790,8 +790,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
)
|
||||
elif merge_mode == "overwrite":
|
||||
existing_model.global_model_id = global_model_id
|
||||
existing_model.provider_model_aliases = model_data.get(
|
||||
"provider_model_aliases"
|
||||
existing_model.provider_model_mappings = model_data.get(
|
||||
"provider_model_mappings"
|
||||
)
|
||||
existing_model.price_per_request = model_data.get(
|
||||
"price_per_request"
|
||||
@@ -824,8 +824,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
provider_id=provider_id,
|
||||
global_model_id=global_model_id,
|
||||
provider_model_name=model_data["provider_model_name"],
|
||||
provider_model_aliases=model_data.get(
|
||||
"provider_model_aliases"
|
||||
provider_model_mappings=model_data.get(
|
||||
"provider_model_mappings"
|
||||
),
|
||||
price_per_request=model_data.get("price_per_request"),
|
||||
tiered_pricing=model_data.get("tiered_pricing"),
|
||||
|
||||
@@ -55,6 +55,23 @@ async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"])
|
||||
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
|
||||
|
||||
|
||||
async def invalidate_models_list_cache() -> None:
    """
    清除所有 /v1/models 列表缓存

    在模型创建、更新、删除时调用,确保模型列表实时更新
    """
    # 清除所有格式的缓存
    all_formats = ["CLAUDE", "OPENAI", "GEMINI"]
    for fmt in all_formats:
        cache_key = f"{_CACHE_KEY_PREFIX}:{fmt}"
        try:
            await CacheService.delete(cache_key)
            logger.debug(f"[ModelsService] 已清除缓存: {cache_key}")
        except Exception as e:
            logger.warning(f"[ModelsService] 清除缓存失败 {cache_key}: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""统一的模型信息结构"""
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import UserRole
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, Provider, RequestCandidate, StatsDaily, Usage
|
||||
from src.models.database import ApiKey, Provider, RequestCandidate, StatsDaily, StatsDailyModel, Usage
|
||||
from src.models.database import User as DBUser
|
||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||
from src.utils.cache_decorator import cache_result
|
||||
@@ -893,69 +893,172 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
})
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# ==================== 模型统计(仍需实时查询)====================
|
||||
model_query = db.query(Usage)
|
||||
if not is_admin:
|
||||
model_query = model_query.filter(Usage.user_id == user.id)
|
||||
model_query = model_query.filter(
|
||||
and_(Usage.created_at >= start_date, Usage.created_at <= end_date)
|
||||
)
|
||||
|
||||
model_stats = (
|
||||
model_query.with_entities(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
# ==================== 模型统计 ====================
|
||||
if is_admin:
|
||||
# 管理员:使用预聚合数据 + 今日实时数据
|
||||
# 历史数据从 stats_daily_model 获取
|
||||
historical_model_stats = (
|
||||
db.query(StatsDailyModel)
|
||||
.filter(and_(StatsDailyModel.date >= start_date, StatsDailyModel.date < today))
|
||||
.all()
|
||||
)
|
||||
.group_by(Usage.model)
|
||||
.order_by(func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
model_summary = [
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
"avg_response_time": (
|
||||
float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0
|
||||
),
|
||||
"cost_per_request": float(stat.cost or 0) / max(stat.requests or 1, 1),
|
||||
"tokens_per_request": int(stat.tokens or 0) / max(stat.requests or 1, 1),
|
||||
}
|
||||
for stat in model_stats
|
||||
]
|
||||
# 按模型汇总历史数据
|
||||
model_agg: dict = {}
|
||||
daily_breakdown: dict = {}
|
||||
|
||||
daily_model_stats = (
|
||||
model_query.with_entities(
|
||||
func.date(Usage.created_at).label("date"),
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
for stat in historical_model_stats:
|
||||
model = stat.model
|
||||
if model not in model_agg:
|
||||
model_agg[model] = {
|
||||
"requests": 0, "tokens": 0, "cost": 0.0,
|
||||
"total_response_time": 0.0, "response_count": 0
|
||||
}
|
||||
model_agg[model]["requests"] += stat.total_requests
|
||||
tokens = (stat.input_tokens + stat.output_tokens +
|
||||
stat.cache_creation_tokens + stat.cache_read_tokens)
|
||||
model_agg[model]["tokens"] += tokens
|
||||
model_agg[model]["cost"] += stat.total_cost
|
||||
if stat.avg_response_time_ms is not None:
|
||||
model_agg[model]["total_response_time"] += stat.avg_response_time_ms * stat.total_requests
|
||||
model_agg[model]["response_count"] += stat.total_requests
|
||||
|
||||
# 按日期分组
|
||||
if stat.date.tzinfo is None:
|
||||
date_utc = stat.date.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
date_utc = stat.date.astimezone(timezone.utc)
|
||||
date_str = date_utc.astimezone(app_tz).date().isoformat()
|
||||
|
||||
daily_breakdown.setdefault(date_str, []).append({
|
||||
"model": model,
|
||||
"requests": stat.total_requests,
|
||||
"tokens": tokens,
|
||||
"cost": stat.total_cost,
|
||||
})
|
||||
|
||||
# 今日实时模型统计
|
||||
today_model_stats = (
|
||||
db.query(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.filter(Usage.created_at >= today)
|
||||
.group_by(Usage.model)
|
||||
.all()
|
||||
)
|
||||
.group_by(func.date(Usage.created_at), Usage.model)
|
||||
.order_by(func.date(Usage.created_at).desc(), func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
breakdown = {}
|
||||
for stat in daily_model_stats:
|
||||
date_str = stat.date.isoformat()
|
||||
breakdown.setdefault(date_str, []).append(
|
||||
today_str = today_local.date().isoformat()
|
||||
for stat in today_model_stats:
|
||||
model = stat.model
|
||||
if model not in model_agg:
|
||||
model_agg[model] = {
|
||||
"requests": 0, "tokens": 0, "cost": 0.0,
|
||||
"total_response_time": 0.0, "response_count": 0
|
||||
}
|
||||
model_agg[model]["requests"] += stat.requests or 0
|
||||
model_agg[model]["tokens"] += int(stat.tokens or 0)
|
||||
model_agg[model]["cost"] += float(stat.cost or 0)
|
||||
if stat.avg_response_time is not None:
|
||||
model_agg[model]["total_response_time"] += float(stat.avg_response_time) * (stat.requests or 0)
|
||||
model_agg[model]["response_count"] += stat.requests or 0
|
||||
|
||||
# 今日 breakdown
|
||||
daily_breakdown.setdefault(today_str, []).append({
|
||||
"model": model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
})
|
||||
|
||||
# 构建 model_summary
|
||||
model_summary = []
|
||||
for model, agg in model_agg.items():
|
||||
avg_rt = (agg["total_response_time"] / agg["response_count"] / 1000.0
|
||||
if agg["response_count"] > 0 else 0)
|
||||
model_summary.append({
|
||||
"model": model,
|
||||
"requests": agg["requests"],
|
||||
"tokens": agg["tokens"],
|
||||
"cost": agg["cost"],
|
||||
"avg_response_time": avg_rt,
|
||||
"cost_per_request": agg["cost"] / max(agg["requests"], 1),
|
||||
"tokens_per_request": agg["tokens"] / max(agg["requests"], 1),
|
||||
})
|
||||
model_summary.sort(key=lambda x: x["cost"], reverse=True)
|
||||
|
||||
# 填充 model_breakdown
|
||||
for item in formatted:
|
||||
item["model_breakdown"] = daily_breakdown.get(item["date"], [])
|
||||
|
||||
else:
|
||||
# 普通用户:实时查询(数据量较小)
|
||||
model_query = db.query(Usage).filter(
|
||||
and_(
|
||||
Usage.user_id == user.id,
|
||||
Usage.created_at >= start_date,
|
||||
Usage.created_at <= end_date
|
||||
)
|
||||
)
|
||||
|
||||
model_stats = (
|
||||
model_query.with_entities(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.group_by(Usage.model)
|
||||
.order_by(func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
model_summary = [
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
"avg_response_time": (
|
||||
float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0
|
||||
),
|
||||
"cost_per_request": float(stat.cost or 0) / max(stat.requests or 1, 1),
|
||||
"tokens_per_request": int(stat.tokens or 0) / max(stat.requests or 1, 1),
|
||||
}
|
||||
for stat in model_stats
|
||||
]
|
||||
|
||||
daily_model_stats = (
|
||||
model_query.with_entities(
|
||||
func.date(Usage.created_at).label("date"),
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
)
|
||||
.group_by(func.date(Usage.created_at), Usage.model)
|
||||
.order_by(func.date(Usage.created_at).desc(), func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for item in formatted:
|
||||
item["model_breakdown"] = breakdown.get(item["date"], [])
|
||||
breakdown = {}
|
||||
for stat in daily_model_stats:
|
||||
date_str = stat.date.isoformat()
|
||||
breakdown.setdefault(date_str, []).append(
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
}
|
||||
)
|
||||
|
||||
for item in formatted:
|
||||
item["model_breakdown"] = breakdown.get(item["date"], [])
|
||||
|
||||
return {
|
||||
"daily_stats": formatted,
|
||||
|
||||
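A standalone sketch of the request-weighted response-time merge performed above, where pre-aggregated historical rows are combined with today's realtime rows; the numbers are illustrative only:

# Each row contributes avg_response_time_ms weighted by its request count.
rows = [
    {"requests": 100, "avg_response_time_ms": 800.0},   # historical day (stats_daily_model)
    {"requests": 50, "avg_response_time_ms": 500.0},    # today, realtime (Usage)
]

total_requests = sum(r["requests"] for r in rows)
total_rt = sum(r["avg_response_time_ms"] * r["requests"] for r in rows)
avg_rt_seconds = (total_rt / total_requests / 1000.0) if total_requests else 0
print(avg_rt_seconds)  # (800*100 + 500*50) / 150 / 1000 = 0.7 seconds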
@@ -376,6 +376,9 @@ class BaseMessageHandler:
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||
|
||||
Args:
|
||||
request_id: 请求 ID,如果不传则使用 self.request_id
|
||||
"""
|
||||
@@ -407,6 +410,9 @@ class BaseMessageHandler:
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||
"""
|
||||
|
||||
@@ -19,8 +19,9 @@ Chat Adapter 通用基类
|
||||
import time
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -620,6 +621,39 @@ class ChatAdapterBase(ApiAdapter):
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""
|
||||
查询上游 API 支持的模型列表
|
||||
|
||||
这是 Aether 内部发起的请求(非用户透传),用于:
|
||||
- 管理后台查询提供商支持的模型
|
||||
- 自动发现可用模型
|
||||
|
||||
Args:
|
||||
client: httpx 异步客户端
|
||||
base_url: API 基础 URL
|
||||
api_key: API 密钥(已解密)
|
||||
extra_headers: 端点配置的额外请求头
|
||||
|
||||
Returns:
|
||||
(models, error): 模型列表和错误信息
|
||||
- models: 模型信息列表,每个模型至少包含 id 字段
|
||||
- error: 错误信息,成功时为 None
|
||||
"""
|
||||
# 默认实现返回空列表,子类应覆盖
|
||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||
|
||||
@@ -260,9 +260,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||
|
||||
if mapping and mapping.model:
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
# 使用 select_provider_model_name 支持映射功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一映射
|
||||
# 传入 api_format 用于过滤适用的映射作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
@@ -484,9 +484,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate)
|
||||
byte_iterator = stream_response.aiter_bytes()
|
||||
|
||||
# 预读检测嵌套错误
|
||||
prefetched_chunks = await stream_processor.prefetch_and_check_error(
|
||||
@@ -639,6 +638,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||
f"模型={model} -> {mapped_model or '无映射'}")
|
||||
logger.debug(f" [{self.request_id}] 请求URL: {url}")
|
||||
logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}")
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
@@ -662,10 +663,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
response_headers=response_headers,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
raise ProviderNotAvailableException(f"提供商服务不可用: {provider.name}")
|
||||
elif resp.status_code != 200:
|
||||
# 记录响应体以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}")
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}"
|
||||
f"提供商服务不可用: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
# 记录非200响应以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}")
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
)
|
||||
|
||||
response_json = resp.json()
|
||||
|
||||
@@ -17,8 +17,9 @@ CLI Adapter 通用基类
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -580,6 +581,39 @@ class CliAdapterBase(ApiAdapter):
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""
|
||||
查询上游 API 支持的模型列表
|
||||
|
||||
这是 Aether 内部发起的请求(非用户透传),用于:
|
||||
- 管理后台查询提供商支持的模型
|
||||
- 自动发现可用模型
|
||||
|
||||
Args:
|
||||
client: httpx 异步客户端
|
||||
base_url: API 基础 URL
|
||||
api_key: API 密钥(已解密)
|
||||
extra_headers: 端点配置的额外请求头
|
||||
|
||||
Returns:
|
||||
(models, error): 模型列表和错误信息
|
||||
- models: 模型信息列表,每个模型至少包含 id 字段
|
||||
- error: 错误信息,成功时为 None
|
||||
"""
|
||||
# 默认实现返回空列表,子类应覆盖
|
||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||
|
||||
@@ -57,8 +57,10 @@ from src.models.database import (
|
||||
ProviderEndpoint,
|
||||
User,
|
||||
)
|
||||
from src.config.settings import config
|
||||
from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||
|
||||
|
||||
class CliMessageHandlerBase(BaseMessageHandler):
|
||||
@@ -136,7 +138,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
|
||||
3. 使用 provider_model_name / provider_model_mappings 选择最终名称
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
@@ -153,9 +155,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
|
||||
if mapping and mapping.model:
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
# 使用 select_provider_model_name 支持模型映射功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一映射
|
||||
# 传入 api_format 用于过滤适用的映射作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
@@ -400,7 +402,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
|
||||
|
||||
# 获取模型映射(别名/映射 → 实际模型名)
|
||||
# 获取模型映射(映射名称 → 实际模型名)
|
||||
mapped_model = await self._get_mapped_model(
|
||||
source_model=ctx.model,
|
||||
provider_id=str(provider.id),
|
||||
@@ -474,8 +476,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate)
|
||||
byte_iterator = stream_response.aiter_bytes()
|
||||
|
||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
||||
@@ -529,7 +531,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
async for chunk in stream_response.aiter_raw():
|
||||
async for chunk in stream_response.aiter_bytes():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||
|
||||
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||
|
||||
Args:
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||
"""
|
||||
prefetched_chunks: list = []
|
||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
else:
|
||||
provider_parser = self.parser
|
||||
|
||||
async for chunk in byte_iterator:
|
||||
# 使用共享的 TTFB 超时函数读取首字节
|
||||
ttfb_timeout = config.stream_first_byte_timeout
|
||||
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||
byte_iterator,
|
||||
timeout=ttfb_timeout,
|
||||
request_id=self.request_id,
|
||||
provider_name=str(provider.name),
|
||||
)
|
||||
prefetched_chunks.append(first_chunk)
|
||||
buffer += first_chunk
|
||||
|
||||
# 继续读取剩余的预读数据
|
||||
async for chunk in aiter:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
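The TTFB helper read_first_chunk_with_ttfb_timeout lives in src/utils/timeout.py and is not part of this diff. The sketch below only illustrates the underlying idea of bounding the first read of the byte iterator; the real helper additionally maps the timeout to ProviderTimeoutException so failover can kick in:

import asyncio
from typing import AsyncIterator, Tuple

async def first_chunk_with_timeout(
    it: AsyncIterator[bytes], timeout: float
) -> Tuple[bytes, AsyncIterator[bytes]]:
    # Bound only the first read; callers keep iterating `it` for the rest.
    first = await asyncio.wait_for(it.__anext__(), timeout=timeout)
    return first, it

async def demo() -> None:
    async def upstream() -> AsyncIterator[bytes]:
        yield b"event: ping\n\n"
        yield b"data: {}\n\n"

    first, rest = await first_chunk_with_timeout(upstream(), timeout=5.0)
    print(first)
    async for chunk in rest:
        print(chunk)

asyncio.run(demo())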
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
# 重新抛出嵌套错误
|
||||
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
# 网络 I/O 异常:记录警告,可能需要重试
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||
logger.error(
|
||||
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
return prefetched_chunks
|
||||
|
||||
@@ -1114,8 +1140,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
async for chunk in stream_generator:
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "Client disconnected"
|
||||
# 如果响应已完成,不标记为失败
|
||||
if not ctx.has_completion:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "Client disconnected"
|
||||
raise
|
||||
except httpx.TimeoutException as e:
|
||||
ctx.status_code = 504
|
||||
@@ -1380,7 +1408,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
provider_name = str(provider.name)
|
||||
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||
|
||||
# 获取模型映射(别名/映射 → 实际模型名)
|
||||
# 获取模型映射(映射名称 → 实际模型名)
|
||||
mapped_model = await self._get_mapped_model(
|
||||
source_model=model,
|
||||
provider_id=str(provider.id),
|
||||
|
||||
src/api/handlers/base/content_extractors.py (new file, 274 lines)
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
流式内容提取器 - 策略模式实现
|
||||
|
||||
为不同 API 格式(OpenAI、Claude、Gemini)提供内容提取和 chunk 构造的抽象。
|
||||
StreamSmoother 使用这些提取器来处理不同格式的 SSE 事件。
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ContentExtractor(ABC):
|
||||
"""
|
||||
流式内容提取器抽象基类
|
||||
|
||||
定义从 SSE 事件中提取文本内容和构造新 chunk 的接口。
|
||||
每种 API 格式(OpenAI、Claude、Gemini)需要实现自己的提取器。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract_content(self, data: dict) -> Optional[str]:
|
||||
"""
|
||||
从 SSE 数据中提取可拆分的文本内容
|
||||
|
||||
Args:
|
||||
data: 解析后的 JSON 数据
|
||||
|
||||
Returns:
|
||||
提取的文本内容,如果无法提取则返回 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_chunk(
|
||||
self,
|
||||
original_data: dict,
|
||||
new_content: str,
|
||||
event_type: str = "",
|
||||
is_first: bool = False,
|
||||
) -> bytes:
|
||||
"""
|
||||
使用新内容构造 SSE chunk
|
||||
|
||||
Args:
|
||||
original_data: 原始 JSON 数据
|
||||
new_content: 新的文本内容
|
||||
event_type: SSE 事件类型(某些格式需要)
|
||||
is_first: 是否是第一个 chunk(用于保留 role 等字段)
|
||||
|
||||
Returns:
|
||||
编码后的 SSE 字节数据
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIContentExtractor(ContentExtractor):
|
||||
"""
|
||||
OpenAI 格式内容提取器
|
||||
|
||||
处理 OpenAI Chat Completions API 的流式响应格式:
|
||||
- 数据结构: choices[0].delta.content
|
||||
- 只在 delta 仅包含 role/content 时允许拆分,避免破坏 tool_calls 等结构
|
||||
"""
|
||||
|
||||
def extract_content(self, data: dict) -> Optional[str]:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
choices = data.get("choices")
|
||||
if not isinstance(choices, list) or len(choices) != 1:
|
||||
return None
|
||||
|
||||
first_choice = choices[0]
|
||||
if not isinstance(first_choice, dict):
|
||||
return None
|
||||
|
||||
delta = first_choice.get("delta")
|
||||
if not isinstance(delta, dict):
|
||||
return None
|
||||
|
||||
content = delta.get("content")
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
|
||||
# 只有 delta 仅包含 role/content 时才允许拆分
|
||||
# 避免破坏 tool_calls、function_call 等复杂结构
|
||||
allowed_keys = {"role", "content"}
|
||||
if not all(key in allowed_keys for key in delta.keys()):
|
||||
return None
|
||||
|
||||
return content
|
||||
|
||||
def create_chunk(
|
||||
self,
|
||||
original_data: dict,
|
||||
new_content: str,
|
||||
event_type: str = "",
|
||||
is_first: bool = False,
|
||||
) -> bytes:
|
||||
new_data = original_data.copy()
|
||||
|
||||
if "choices" in new_data and new_data["choices"]:
|
||||
new_choices = []
|
||||
for choice in new_data["choices"]:
|
||||
new_choice = choice.copy()
|
||||
if "delta" in new_choice:
|
||||
new_delta = {}
|
||||
# 只有第一个 chunk 保留 role
|
||||
if is_first and "role" in new_choice["delta"]:
|
||||
new_delta["role"] = new_choice["delta"]["role"]
|
||||
new_delta["content"] = new_content
|
||||
new_choice["delta"] = new_delta
|
||||
new_choices.append(new_choice)
|
||||
new_data["choices"] = new_choices
|
||||
|
||||
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
class ClaudeContentExtractor(ContentExtractor):
|
||||
"""
|
||||
Claude 格式内容提取器
|
||||
|
||||
处理 Claude Messages API 的流式响应格式:
|
||||
- 事件类型: content_block_delta
|
||||
- 数据结构: delta.type=text_delta, delta.text
|
||||
"""
|
||||
|
||||
def extract_content(self, data: dict) -> Optional[str]:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
# 检查事件类型
|
||||
if data.get("type") != "content_block_delta":
|
||||
return None
|
||||
|
||||
delta = data.get("delta", {})
|
||||
if not isinstance(delta, dict):
|
||||
return None
|
||||
|
||||
# 检查 delta 类型
|
||||
if delta.get("type") != "text_delta":
|
||||
return None
|
||||
|
||||
text = delta.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
|
||||
return text
|
||||
|
||||
def create_chunk(
|
||||
self,
|
||||
original_data: dict,
|
||||
new_content: str,
|
||||
event_type: str = "",
|
||||
is_first: bool = False,
|
||||
) -> bytes:
|
||||
new_data = original_data.copy()
|
||||
|
||||
if "delta" in new_data:
|
||||
new_delta = new_data["delta"].copy()
|
||||
new_delta["text"] = new_content
|
||||
new_data["delta"] = new_delta
|
||||
|
||||
# Claude 格式需要 event: 前缀
|
||||
event_name = event_type or "content_block_delta"
|
||||
return f"event: {event_name}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
|
||||
class GeminiContentExtractor(ContentExtractor):
|
||||
"""
|
||||
Gemini 格式内容提取器
|
||||
|
||||
处理 Gemini API 的流式响应格式:
|
||||
- 数据结构: candidates[0].content.parts[0].text
|
||||
- 只有纯文本块才拆分
|
||||
"""
|
||||
|
||||
def extract_content(self, data: dict) -> Optional[str]:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
candidates = data.get("candidates")
|
||||
if not isinstance(candidates, list) or len(candidates) != 1:
|
||||
return None
|
||||
|
||||
first_candidate = candidates[0]
|
||||
if not isinstance(first_candidate, dict):
|
||||
return None
|
||||
|
||||
content = first_candidate.get("content", {})
|
||||
if not isinstance(content, dict):
|
||||
return None
|
||||
|
||||
parts = content.get("parts", [])
|
||||
if not isinstance(parts, list) or len(parts) != 1:
|
||||
return None
|
||||
|
||||
first_part = parts[0]
|
||||
if not isinstance(first_part, dict):
|
||||
return None
|
||||
|
||||
text = first_part.get("text")
|
||||
# 只有纯文本块(只有 text 字段)才拆分
|
||||
if not isinstance(text, str) or len(first_part) != 1:
|
||||
return None
|
||||
|
||||
return text
|
||||
|
||||
def create_chunk(
|
||||
self,
|
||||
original_data: dict,
|
||||
new_content: str,
|
||||
event_type: str = "",
|
||||
is_first: bool = False,
|
||||
) -> bytes:
|
||||
new_data = copy.deepcopy(original_data)
|
||||
|
||||
if "candidates" in new_data and new_data["candidates"]:
|
||||
first_candidate = new_data["candidates"][0]
|
||||
if "content" in first_candidate:
|
||||
content = first_candidate["content"]
|
||||
if "parts" in content and content["parts"]:
|
||||
content["parts"][0]["text"] = new_content
|
||||
|
||||
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
# 提取器注册表
|
||||
_EXTRACTORS: dict[str, type[ContentExtractor]] = {
|
||||
"openai": OpenAIContentExtractor,
|
||||
"claude": ClaudeContentExtractor,
|
||||
"gemini": GeminiContentExtractor,
|
||||
}
|
||||
|
||||
|
||||
def get_extractor(format_name: str) -> Optional[ContentExtractor]:
|
||||
"""
|
||||
根据格式名获取对应的内容提取器实例
|
||||
|
||||
Args:
|
||||
format_name: 格式名称(openai, claude, gemini)
|
||||
|
||||
Returns:
|
||||
对应的提取器实例,如果格式不支持则返回 None
|
||||
"""
|
||||
extractor_class = _EXTRACTORS.get(format_name.lower())
|
||||
if extractor_class:
|
||||
return extractor_class()
|
||||
return None
|
||||
|
||||
|
||||
def register_extractor(format_name: str, extractor_class: type[ContentExtractor]) -> None:
|
||||
"""
|
||||
注册新的内容提取器
|
||||
|
||||
Args:
|
||||
format_name: 格式名称
|
||||
extractor_class: 提取器类
|
||||
"""
|
||||
_EXTRACTORS[format_name.lower()] = extractor_class
|
||||
|
||||
|
||||
def get_extractor_formats() -> list[str]:
|
||||
"""
|
||||
获取所有已注册的格式名称列表
|
||||
|
||||
Returns:
|
||||
格式名称列表
|
||||
"""
|
||||
return list(_EXTRACTORS.keys())
|
||||
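A usage sketch for the registry above, fed with a made-up OpenAI-style delta event:

from src.api.handlers.base.content_extractors import get_extractor

data = {"choices": [{"delta": {"role": "assistant", "content": "Hello world"}}]}

extractor = get_extractor("openai")
text = extractor.extract_content(data)               # -> "Hello world"
chunk = extractor.create_chunk(data, "Hello", is_first=True)
# chunk is a single SSE line: b'data: {... "content": "Hello" ...}\n\n'
print(text, chunk)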
@@ -6,30 +6,47 @@
|
||||
2. 响应流生成
|
||||
3. 预读和嵌套错误检测
|
||||
4. 客户端断开检测
|
||||
5. 流式平滑输出
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Callable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from src.api.handlers.base.content_extractors import (
|
||||
ContentExtractor,
|
||||
get_extractor,
|
||||
get_extractor_formats,
|
||||
)
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.core.exceptions import EmbeddedErrorException
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderEndpoint
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamSmoothingConfig:
|
||||
"""流式平滑输出配置"""
|
||||
|
||||
enabled: bool = False
|
||||
chunk_size: int = 20
|
||||
delay_ms: int = 8
|
||||
|
||||
|
||||
class StreamProcessor:
|
||||
"""
|
||||
流式响应处理器
|
||||
|
||||
负责处理 SSE 流的解析、错误检测和响应生成。
|
||||
负责处理 SSE 流的解析、错误检测、响应生成和平滑输出。
|
||||
从 ChatHandlerBase 中提取,使其职责更加单一。
|
||||
"""
|
||||
|
||||
@@ -40,6 +57,7 @@ class StreamProcessor:
|
||||
on_streaming_start: Optional[Callable[[], None]] = None,
|
||||
*,
|
||||
collect_text: bool = False,
|
||||
smoothing_config: Optional[StreamSmoothingConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化流处理器
|
||||
@@ -48,11 +66,17 @@ class StreamProcessor:
|
||||
request_id: 请求 ID(用于日志)
|
||||
default_parser: 默认响应解析器
|
||||
on_streaming_start: 流开始时的回调(用于更新状态)
|
||||
collect_text: 是否收集文本内容
|
||||
smoothing_config: 流式平滑输出配置
|
||||
"""
|
||||
self.request_id = request_id
|
||||
self.default_parser = default_parser
|
||||
self.on_streaming_start = on_streaming_start
|
||||
self.collect_text = collect_text
|
||||
self.smoothing_config = smoothing_config or StreamSmoothingConfig()
|
||||
|
||||
# 内容提取器缓存
|
||||
self._extractors: dict[str, ContentExtractor] = {}
|
||||
|
||||
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
||||
"""
|
||||
@@ -127,6 +151,13 @@ class StreamProcessor:
|
||||
if event_type in ("response.completed", "message_stop"):
|
||||
ctx.has_completion = True
|
||||
|
||||
# 检查 OpenAI 格式的 finish_reason
|
||||
choices = data.get("choices", [])
|
||||
if choices and isinstance(choices, list) and len(choices) > 0:
|
||||
finish_reason = choices[0].get("finish_reason")
|
||||
if finish_reason is not None:
|
||||
ctx.has_completion = True
|
||||
|
||||
async def prefetch_and_check_error(
|
||||
self,
|
||||
byte_iterator: Any,
|
||||
@@ -141,6 +172,8 @@ class StreamProcessor:
|
||||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||
|
||||
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||
|
||||
Args:
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
@@ -153,6 +186,7 @@ class StreamProcessor:
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||
"""
|
||||
prefetched_chunks: list = []
|
||||
parser = self.get_parser_for_provider(ctx)
|
||||
@@ -163,7 +197,19 @@ class StreamProcessor:
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
try:
|
||||
async for chunk in byte_iterator:
|
||||
# 使用共享的 TTFB 超时函数读取首字节
|
||||
ttfb_timeout = config.stream_first_byte_timeout
|
||||
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||
byte_iterator,
|
||||
timeout=ttfb_timeout,
|
||||
request_id=self.request_id,
|
||||
provider_name=str(provider.name),
|
||||
)
|
||||
prefetched_chunks.append(first_chunk)
|
||||
buffer += first_chunk
|
||||
|
||||
# 继续读取剩余的预读数据
|
||||
async for chunk in aiter:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
@@ -233,10 +279,21 @@ class StreamProcessor:
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
except (EmbeddedErrorException, ProviderTimeoutException):
|
||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
# 网络 I/O 异常:记录警告,可能需要重试
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||
logger.error(
|
||||
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
return prefetched_chunks
|
||||
|
||||
@@ -369,7 +426,7 @@ class StreamProcessor:
|
||||
sse_parser: SSE 解析器
|
||||
line: 原始行数据
|
||||
"""
|
||||
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF,
|
||||
# SSEEventParser 以"去掉换行符"的单行文本作为输入;这里统一剔除 CR/LF,
|
||||
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
||||
normalized_line = line.rstrip("\r\n")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
@@ -400,32 +457,201 @@ class StreamProcessor:
|
||||
响应数据块
|
||||
"""
|
||||
try:
|
||||
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
|
||||
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
|
||||
next_disconnect_check_at = 0.0
|
||||
disconnect_check_interval_s = 0.25
|
||||
# 使用后台任务检查断连,完全不阻塞流式传输
|
||||
disconnected = False
|
||||
|
||||
async for chunk in stream_generator:
|
||||
now = time.monotonic()
|
||||
if now >= next_disconnect_check_at:
|
||||
next_disconnect_check_at = now + disconnect_check_interval_s
|
||||
async def check_disconnect_background() -> None:
|
||||
nonlocal disconnected
|
||||
while not disconnected and not ctx.has_completion:
|
||||
await asyncio.sleep(0.5)
|
||||
if await is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499 # Client Closed Request
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
disconnected = True
|
||||
break
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
# 启动后台检查任务
|
||||
check_task = asyncio.create_task(check_disconnect_background())
|
||||
|
||||
try:
|
||||
async for chunk in stream_generator:
|
||||
if disconnected:
|
||||
# 如果响应已完成,客户端断开不算失败
|
||||
if ctx.has_completion:
|
||||
logger.info(
|
||||
f"ID:{self.request_id} | Client disconnected after completion"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "client_disconnected"
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
check_task.cancel()
|
||||
try:
|
||||
await check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
# 如果响应已完成,不标记为失败
|
||||
if not ctx.has_completion:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "client_disconnected"
|
||||
raise
|
||||
except Exception as e:
|
||||
ctx.status_code = 500
|
||||
ctx.error_message = str(e)
|
||||
raise
|
||||
|
||||
async def create_smoothed_stream(
|
||||
self,
|
||||
stream_generator: AsyncGenerator[bytes, None],
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
创建平滑输出的流生成器
|
||||
|
||||
如果启用了平滑输出,将大 chunk 拆分成小块并添加微小延迟。
|
||||
否则直接透传原始流。
|
||||
|
||||
Args:
|
||||
stream_generator: 原始流生成器
|
||||
|
||||
Yields:
|
||||
平滑处理后的响应数据块
|
||||
"""
|
||||
if not self.smoothing_config.enabled:
|
||||
# 未启用平滑输出,直接透传
|
||||
async for chunk in stream_generator:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# 启用平滑输出
|
||||
buffer = b""
|
||||
is_first_content = True
|
||||
|
||||
async for chunk in stream_generator:
|
||||
buffer += chunk
|
||||
|
||||
# 按双换行分割 SSE 事件(标准 SSE 格式)
|
||||
while b"\n\n" in buffer:
|
||||
event_block, buffer = buffer.split(b"\n\n", 1)
|
||||
event_str = event_block.decode("utf-8", errors="replace")
|
||||
|
||||
# 解析事件块
|
||||
lines = event_str.strip().split("\n")
|
||||
data_str = None
|
||||
event_type = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.rstrip("\r")
|
||||
if line.startswith("event: "):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
# 没有 data 行,直接透传
|
||||
if data_str is None:
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
# [DONE] 直接透传
|
||||
if data_str.strip() == "[DONE]":
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
# 检测格式并提取内容
|
||||
content, extractor = self._detect_format_and_extract(data)
|
||||
|
||||
# 只有内容长度大于 1 才需要平滑处理
|
||||
if content and len(content) > 1 and extractor:
|
||||
# 获取配置的延迟
|
||||
delay_seconds = self._calculate_delay()
|
||||
|
||||
# 拆分内容
|
||||
content_chunks = self._split_content(content)
|
||||
|
||||
for i, sub_content in enumerate(content_chunks):
|
||||
is_first = is_first_content and i == 0
|
||||
|
||||
# 使用提取器创建新 chunk
|
||||
sse_chunk = extractor.create_chunk(
|
||||
data,
|
||||
sub_content,
|
||||
event_type=event_type,
|
||||
is_first=is_first,
|
||||
)
|
||||
|
||||
yield sse_chunk
|
||||
|
||||
# 除了最后一个块,其他块之间加延迟
|
||||
if i < len(content_chunks) - 1:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
is_first_content = False
|
||||
else:
|
||||
# 不需要拆分,直接透传
|
||||
yield event_block + b"\n\n"
|
||||
if content:
|
||||
is_first_content = False
|
||||
|
||||
# 处理剩余数据
|
||||
if buffer:
|
||||
yield buffer
|
||||
|
||||
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
|
||||
"""获取或创建格式对应的提取器(带缓存)"""
|
||||
if format_name not in self._extractors:
|
||||
extractor = get_extractor(format_name)
|
||||
if extractor:
|
||||
self._extractors[format_name] = extractor
|
||||
return self._extractors.get(format_name)
|
||||
|
||||
def _detect_format_and_extract(
|
||||
self, data: dict
|
||||
) -> tuple[Optional[str], Optional[ContentExtractor]]:
|
||||
"""
|
||||
检测数据格式并提取内容
|
||||
|
||||
依次尝试各格式的提取器,返回第一个成功提取内容的结果。
|
||||
|
||||
Returns:
|
||||
(content, extractor): 提取的内容和对应的提取器
|
||||
"""
|
||||
for format_name in get_extractor_formats():
|
||||
extractor = self._get_extractor(format_name)
|
||||
if extractor:
|
||||
content = extractor.extract_content(data)
|
||||
if content is not None:
|
||||
return content, extractor
|
||||
|
||||
return None, None
|
||||
|
||||
def _calculate_delay(self) -> float:
|
||||
"""获取配置的延迟(秒)"""
|
||||
return self.smoothing_config.delay_ms / 1000.0
|
||||
|
||||
def _split_content(self, content: str) -> list[str]:
|
||||
"""
|
||||
按块拆分文本
|
||||
"""
|
||||
chunk_size = self.smoothing_config.chunk_size
|
||||
text_length = len(content)
|
||||
|
||||
if text_length <= chunk_size:
|
||||
return [content]
|
||||
|
||||
# 按块拆分
|
||||
chunks = []
|
||||
for i in range(0, text_length, chunk_size):
|
||||
chunks.append(content[i : i + chunk_size])
|
||||
return chunks
|
||||
|
||||
async def _cleanup(
|
||||
self,
|
||||
response_ctx: Any,
|
||||
@@ -440,3 +666,128 @@ class StreamProcessor:
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def create_smoothed_stream(
|
||||
stream_generator: AsyncGenerator[bytes, None],
|
||||
chunk_size: int = 20,
|
||||
delay_ms: int = 8,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
独立的平滑流生成函数
|
||||
|
||||
供 CLI handler 等场景使用,无需创建完整的 StreamProcessor 实例。
|
||||
|
||||
Args:
|
||||
stream_generator: 原始流生成器
|
||||
chunk_size: 每块字符数
|
||||
delay_ms: 每块之间的延迟毫秒数
|
||||
|
||||
Yields:
|
||||
平滑处理后的响应数据块
|
||||
"""
|
||||
processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
|
||||
async for chunk in processor.smooth(stream_generator):
|
||||
yield chunk
|
||||
|
||||
|
||||
class _LightweightSmoother:
|
||||
"""
|
||||
轻量级平滑处理器
|
||||
|
||||
只包含平滑输出所需的最小逻辑,不依赖 StreamProcessor 的其他功能。
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
|
||||
self.chunk_size = chunk_size
|
||||
self.delay_ms = delay_ms
|
||||
self._extractors: dict[str, ContentExtractor] = {}
|
||||
|
||||
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
|
||||
if format_name not in self._extractors:
|
||||
extractor = get_extractor(format_name)
|
||||
if extractor:
|
||||
self._extractors[format_name] = extractor
|
||||
return self._extractors.get(format_name)
|
||||
|
||||
def _detect_format_and_extract(
|
||||
self, data: dict
|
||||
) -> tuple[Optional[str], Optional[ContentExtractor]]:
|
||||
for format_name in get_extractor_formats():
|
||||
extractor = self._get_extractor(format_name)
|
||||
if extractor:
|
||||
content = extractor.extract_content(data)
|
||||
if content is not None:
|
||||
return content, extractor
|
||||
return None, None
|
||||
|
||||
def _calculate_delay(self) -> float:
|
||||
return self.delay_ms / 1000.0
|
||||
|
||||
def _split_content(self, content: str) -> list[str]:
|
||||
text_length = len(content)
|
||||
if text_length <= self.chunk_size:
|
||||
return [content]
|
||||
return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]
|
||||
|
||||
async def smooth(
|
||||
self, stream_generator: AsyncGenerator[bytes, None]
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
buffer = b""
|
||||
is_first_content = True
|
||||
|
||||
async for chunk in stream_generator:
|
||||
buffer += chunk
|
||||
|
||||
while b"\n\n" in buffer:
|
||||
event_block, buffer = buffer.split(b"\n\n", 1)
|
||||
event_str = event_block.decode("utf-8", errors="replace")
|
||||
|
||||
lines = event_str.strip().split("\n")
|
||||
data_str = None
|
||||
event_type = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.rstrip("\r")
|
||||
if line.startswith("event: "):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
if data_str is None:
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
if data_str.strip() == "[DONE]":
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
yield event_block + b"\n\n"
|
||||
continue
|
||||
|
||||
content, extractor = self._detect_format_and_extract(data)
|
||||
|
||||
if content and len(content) > 1 and extractor:
|
||||
delay_seconds = self._calculate_delay()
|
||||
content_chunks = self._split_content(content)
|
||||
|
||||
for i, sub_content in enumerate(content_chunks):
|
||||
is_first = is_first_content and i == 0
|
||||
sse_chunk = extractor.create_chunk(
|
||||
data, sub_content, event_type=event_type, is_first=is_first
|
||||
)
|
||||
yield sse_chunk
|
||||
if i < len(content_chunks) - 1:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
is_first_content = False
|
||||
else:
|
||||
yield event_block + b"\n\n"
|
||||
if content:
|
||||
is_first_content = False
|
||||
|
||||
if buffer:
|
||||
yield buffer
|
||||
|
||||
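A usage sketch for the standalone create_smoothed_stream helper above; the module path in the import is an assumption, since this diff does not name the file:

import asyncio
import json
from typing import AsyncIterator

from src.api.handlers.base.stream_processor import create_smoothed_stream  # assumed path

async def upstream() -> AsyncIterator[bytes]:
    # One large OpenAI-style delta; the smoother splits it into 20-char pieces.
    event = {"choices": [{"delta": {"content": "0123456789" * 4}}]}
    yield f"data: {json.dumps(event)}\n\n".encode()
    yield b"data: [DONE]\n\n"

async def main() -> None:
    async for piece in create_smoothed_stream(upstream(), chunk_size=20, delay_ms=8):
        print(piece)

asyncio.run(main())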
@@ -4,17 +4,28 @@ Handler 基础工具函数
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取缓存创建 tokens(兼容新旧格式)
|
||||
提取缓存创建 tokens(兼容三种格式)
|
||||
|
||||
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens:
|
||||
- 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和
|
||||
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
|
||||
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
|
||||
根据 Anthropic API 文档,支持三种格式(按优先级):
|
||||
|
||||
此函数自动检测并适配两种格式,优先使用新格式。
|
||||
1. **嵌套格式(优先级最高)**:
|
||||
usage.cache_creation.ephemeral_5m_input_tokens
|
||||
usage.cache_creation.ephemeral_1h_input_tokens
|
||||
|
||||
2. **扁平新格式(优先级第二)**:
|
||||
usage.claude_cache_creation_5_m_tokens
|
||||
usage.claude_cache_creation_1_h_tokens
|
||||
|
||||
3. **旧格式(优先级第三)**:
|
||||
usage.cache_creation_input_tokens
|
||||
|
||||
优先使用嵌套格式,如果嵌套格式字段存在但值为 0,则智能 fallback 到旧格式。
|
||||
扁平格式和嵌套格式互斥,按顺序检查。
|
||||
|
||||
Args:
|
||||
usage: API 响应中的 usage 字典
|
||||
@@ -22,20 +33,63 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||
Returns:
|
||||
缓存创建 tokens 总数
|
||||
"""
|
||||
# 检查新格式字段是否存在(而非值是否为 0)
|
||||
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
|
||||
has_new_format = (
|
||||
# 1. 检查嵌套格式(最新格式)
|
||||
cache_creation = usage.get("cache_creation")
|
||||
if isinstance(cache_creation, dict):
|
||||
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
|
||||
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
|
||||
total = cache_5m + cache_1h
|
||||
|
||||
if total > 0:
|
||||
logger.debug(
|
||||
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||
)
|
||||
return total
|
||||
|
||||
# 嵌套格式存在但为 0,fallback 到旧格式
|
||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||
if old_format > 0:
|
||||
logger.debug(
|
||||
f"Nested cache_creation is 0, using old format: {old_format}"
|
||||
)
|
||||
return old_format
|
||||
|
||||
# 都是 0,返回 0
|
||||
return 0
|
||||
|
||||
# 2. 检查扁平新格式
|
||||
has_flat_format = (
|
||||
"claude_cache_creation_5_m_tokens" in usage
|
||||
or "claude_cache_creation_1_h_tokens" in usage
|
||||
)
|
||||
|
||||
if has_new_format:
|
||||
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
|
||||
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
|
||||
return int(cache_5m) + int(cache_1h)
|
||||
if has_flat_format:
|
||||
cache_5m = int(usage.get("claude_cache_creation_5_m_tokens", 0))
|
||||
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
|
||||
total = cache_5m + cache_1h
|
||||
|
||||
# 回退到旧格式
|
||||
return int(usage.get("cache_creation_input_tokens", 0))
|
||||
if total > 0:
|
||||
logger.debug(
|
||||
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||
)
|
||||
return total
|
||||
|
||||
# 扁平格式存在但为 0,fallback 到旧格式
|
||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||
if old_format > 0:
|
||||
logger.debug(
|
||||
f"Flat cache_creation is 0, using old format: {old_format}"
|
||||
)
|
||||
return old_format
|
||||
|
||||
# 都是 0,返回 0
|
||||
return 0
|
||||
|
||||
# 3. 回退到旧格式
|
||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||
if old_format > 0:
|
||||
logger.debug(f"Using old format: cache_creation_input_tokens={old_format}")
|
||||
return old_format
|
||||
|
||||
|
||||
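The three usage shapes handled above, side by side; all of them describe the same 150 cache-creation tokens (values are illustrative):

nested = {"cache_creation": {"ephemeral_5m_input_tokens": 100,
                             "ephemeral_1h_input_tokens": 50}}
flat = {"claude_cache_creation_5_m_tokens": 100,
        "claude_cache_creation_1_h_tokens": 50}
legacy = {"cache_creation_input_tokens": 150}

# extract_cache_creation_tokens(nested) == extract_cache_creation_tokens(flat)
#     == extract_cache_creation_tokens(legacy) == 150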
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
|
||||
@@ -4,8 +4,9 @@ Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
||||
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -155,6 +156,59 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"thinking_enabled": bool(request_obj.thinking),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 Claude API 支持的模型列表"""
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖认证头
|
||||
safe_headers = {
|
||||
k: v for k, v in extra_headers.items()
|
||||
if k.lower() not in ("x-api-key", "authorization", "anthropic-version")
|
||||
}
|
||||
headers.update(safe_headers)
|
||||
|
||||
# 构建 /v1/models URL
|
||||
base_url = base_url.rstrip("/")
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = cls.FORMAT_ID
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
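A usage sketch showing how an admin-side caller might invoke the classmethod above; the API key and base URL are placeholders:

import asyncio
import httpx
from src.api.handlers.claude.adapter import ClaudeChatAdapter

async def main() -> None:
    async with httpx.AsyncClient(timeout=30) as client:
        models, error = await ClaudeChatAdapter.fetch_models(
            client, "https://api.anthropic.com", "sk-ant-placeholder"
        )
        if error:
            print(f"fetch failed: {error}")
        else:
            print([m["id"] for m in models])

asyncio.run(main())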
def build_claude_adapter(x_app_header: Optional[str]):
|
||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||
|
||||
@@ -4,13 +4,15 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
|
||||
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector, ClaudeChatAdapter
|
||||
from src.config.settings import config
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
@@ -99,5 +101,30 @@ class ClaudeCliAdapter(CliAdapterBase):
|
||||
"system_present": bool(payload.get("system")),
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
||||
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
||||
cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
|
||||
if extra_headers:
|
||||
cli_headers.update(extra_headers)
|
||||
models, error = await ClaudeChatAdapter.fetch_models(
|
||||
client, base_url, api_key, cli_headers
|
||||
)
|
||||
# 更新 api_format 为 CLI 格式
|
||||
for m in models:
|
||||
m["api_format"] = cls.FORMAT_ID
|
||||
return models, error
|
||||
|
||||
|
||||
__all__ = ["ClaudeCliAdapter"]
|
||||
|
||||
@@ -4,8 +4,9 @@ Gemini Chat Adapter
|
||||
处理 Gemini API 格式的请求适配
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -151,6 +152,53 @@ class GeminiChatAdapter(ChatAdapterBase):
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 Gemini API 支持的模型列表"""
|
||||
# 兼容 base_url 已包含 /v1beta 的情况
|
||||
base_url_clean = base_url.rstrip("/")
|
||||
if base_url_clean.endswith("/v1beta"):
|
||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||
else:
|
||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "models" in data:
|
||||
# 转换为统一格式
|
||||
return [
|
||||
{
|
||||
"id": m.get("name", "").replace("models/", ""),
|
||||
"owned_by": "google",
|
||||
"display_name": m.get("displayName", ""),
|
||||
"api_format": cls.FORMAT_ID,
|
||||
}
|
||||
for m in data["models"]
|
||||
], None
|
||||
return [], None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||
"""
|
||||
|
||||
@@ -4,12 +4,15 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.api.handlers.gemini.adapter import GeminiChatAdapter
|
||||
from src.config.settings import config
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
@@ -95,6 +98,31 @@ class GeminiCliAdapter(CliAdapterBase):
|
||||
"safety_settings_count": len(payload.get("safety_settings") or []),
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
||||
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
||||
cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
|
||||
if extra_headers:
|
||||
cli_headers.update(extra_headers)
|
||||
models, error = await GeminiChatAdapter.fetch_models(
|
||||
client, base_url, api_key, cli_headers
|
||||
)
|
||||
# 更新 api_format 为 CLI 格式
|
||||
for m in models:
|
||||
m["api_format"] = cls.FORMAT_ID
|
||||
return models, error
|
||||
|
||||
|
||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||
"""
|
||||
|
||||
@@ -4,8 +4,9 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -105,5 +106,53 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 OpenAI 兼容 API 支持的模型列表"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖 Authorization
|
||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||
headers.update(safe_headers)
|
||||
|
||||
# 构建 /v1/models URL
|
||||
base_url = base_url.rstrip("/")
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = cls.FORMAT_ID
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
__all__ = ["OpenAIChatAdapter"]
|
||||
|
||||
@@ -4,12 +4,15 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.api.handlers.openai.adapter import OpenAIChatAdapter
|
||||
from src.config.settings import config
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
@@ -40,5 +43,30 @@ class OpenAICliAdapter(CliAdapterBase):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
async def fetch_models(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[list, Optional[str]]:
|
||||
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
||||
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
||||
cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
|
||||
if extra_headers:
|
||||
cli_headers.update(extra_headers)
|
||||
models, error = await OpenAIChatAdapter.fetch_models(
|
||||
client, base_url, api_key, cli_headers
|
||||
)
|
||||
# 更新 api_format 为 CLI 格式
|
||||
for m in models:
|
||||
m["api_format"] = cls.FORMAT_ID
|
||||
return models, error
|
||||
|
||||
|
||||
__all__ = ["OpenAICliAdapter"]
|
||||
|
||||
@@ -77,7 +77,10 @@ class ConcurrencyDefaults:
|
||||
MAX_CONCURRENT_LIMIT = 200
|
||||
|
||||
# 最小并发限制下限
|
||||
MIN_CONCURRENT_LIMIT = 1
|
||||
# 设置为 3 而不是 1,因为预留机制(10%预留给缓存用户)会导致
|
||||
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0,永远无法命中
|
||||
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0),这是可接受的
|
||||
MIN_CONCURRENT_LIMIT = 3
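
A rough arithmetic sketch of the floor chosen above, assuming the cache reservation is computed as int(limit * CACHE_RESERVATION_RATIO) with the 0.1 default that appears later in this diff; the helper is purely illustrative and is not the scheduler's actual code.

def reserved_slots(limit: int, ratio: float = 0.1) -> int:
    # assumed reservation rule; below 10 concurrent slots this truncates to zero
    return int(limit * ratio)

for limit in (1, 3, 9, 10, 20):
    print(limit, reserved_slots(limit), limit - reserved_slots(limit))
# 1 -> 0 reserved, 3 -> 0, 9 -> 0, 10 -> 1, 20 -> 2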
|
||||
|
||||
# === 探测性扩容参数 ===
|
||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||
|
||||
@@ -56,10 +56,11 @@ class Config:
|
||||
|
||||
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
||||
redis_required_env = os.getenv("REDIS_REQUIRED")
|
||||
if redis_required_env is None:
|
||||
self.require_redis = self.environment not in {"development", "test", "testing"}
|
||||
else:
|
||||
if redis_required_env is not None:
|
||||
self.require_redis = redis_required_env.lower() == "true"
|
||||
else:
|
||||
# 保持向后兼容:开发环境可选,生产环境必需
|
||||
self.require_redis = self.environment not in {"development", "test", "testing"}
|
||||
|
||||
# CORS配置 - 使用环境变量配置允许的源
|
||||
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
||||
@@ -133,6 +134,18 @@ class Config:
|
||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||
|
||||
# 限流降级策略配置
|
||||
# RATE_LIMIT_FAIL_OPEN: 当限流服务(Redis)异常时的行为
|
||||
#
|
||||
# True (默认): fail-open - 放行请求(优先可用性)
|
||||
# 风险:Redis 故障期间无法限流,可能被滥用
|
||||
# 适用:API 网关作为关键基础设施,必须保持高可用
|
||||
#
|
||||
# False: fail-close - 拒绝所有请求(优先安全性)
|
||||
# 风险:Redis 故障会导致 API 网关不可用
|
||||
# 适用:有严格速率限制要求的安全敏感场景
|
||||
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
|
||||
|
||||
# HTTP 请求超时配置(秒)
|
||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||
@@ -141,8 +154,23 @@ class Config:
|
||||
# 流式处理配置
|
||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||
# STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移
|
||||
# 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态)
|
||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||
|
||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||
self.internal_user_agent_claude_cli = os.getenv(
|
||||
"CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
|
||||
)
|
||||
self.internal_user_agent_openai_cli = os.getenv(
|
||||
"OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
|
||||
)
|
||||
self.internal_user_agent_gemini_cli = os.getenv(
|
||||
"GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
|
||||
)
|
||||
|
||||
# 验证连接池配置
|
||||
self._validate_pool_config()
|
||||
@@ -165,6 +193,39 @@ class Config:
|
||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||
return self.db_pool_size
|
||||
|
||||
def _parse_ttfb_timeout(self) -> float:
|
||||
"""
|
||||
解析 TTFB 超时配置,带错误处理和范围限制
|
||||
|
||||
TTFB (Time To First Byte) 用于检测慢响应的 Provider,超时触发故障转移。
|
||||
此值必须小于 http_write_timeout,避免竞态条件。
|
||||
|
||||
Returns:
|
||||
超时时间(秒),范围 10-120,默认 30
|
||||
"""
|
||||
default_timeout = 30.0
|
||||
min_timeout = 10.0
|
||||
max_timeout = 120.0 # 必须小于 http_write_timeout (默认 60s) 的 2 倍
|
||||
|
||||
raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout))
|
||||
try:
|
||||
timeout = float(raw_value)
|
||||
except ValueError:
|
||||
# 延迟导入,避免循环依赖(Config 初始化时 logger 可能未就绪)
|
||||
self._ttfb_config_warning = (
|
||||
f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}秒"
|
||||
)
|
||||
return default_timeout
|
||||
|
||||
# 范围限制
|
||||
clamped = max(min_timeout, min(max_timeout, timeout))
|
||||
if clamped != timeout:
|
||||
self._ttfb_config_warning = (
|
||||
f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}],"
|
||||
f"已调整为 {clamped}秒"
|
||||
)
|
||||
return clamped
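
The same clamp rule, pulled out into a standalone sketch so the boundary behaviour is easy to check; the helper name and sample values are illustrative only.

def clamp_ttfb(raw: str, default: float = 30.0, lo: float = 10.0, hi: float = 120.0) -> float:
    try:
        value = float(raw)
    except ValueError:
        return default
    return max(lo, min(hi, value))

assert clamp_ttfb("45") == 45.0
assert clamp_ttfb("5") == 10.0      # clamped up to the minimum
assert clamp_ttfb("200") == 120.0   # clamped down to the maximum
assert clamp_ttfb("abc") == 30.0    # parse failure falls back to the default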
|
||||
|
||||
def _validate_pool_config(self) -> None:
|
||||
"""验证连接池配置是否安全"""
|
||||
total_per_worker = self.db_pool_size + self.db_max_overflow
|
||||
@@ -212,6 +273,10 @@ class Config:
|
||||
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
||||
logger.warning(self._pool_config_warning)
|
||||
|
||||
# TTFB 超时配置警告
|
||||
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
|
||||
logger.warning(self._ttfb_config_warning)
|
||||
|
||||
# 管理员密码检查(必须在环境变量中设置)
|
||||
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
||||
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
||||
|
||||
@@ -10,8 +10,8 @@ class APIFormat(Enum):
|
||||
"""API格式枚举 - 决定请求/响应的处理方式"""
|
||||
|
||||
CLAUDE = "CLAUDE" # Claude API 格式
|
||||
OPENAI = "OPENAI" # OpenAI API 格式
|
||||
CLAUDE_CLI = "CLAUDE_CLI" # Claude CLI API 格式(使用 authorization: Bearer)
|
||||
OPENAI = "OPENAI" # OpenAI API 格式
|
||||
OPENAI_CLI = "OPENAI_CLI" # OpenAI CLI/Responses API 格式(用于 Claude Code 等客户端)
|
||||
GEMINI = "GEMINI" # Google Gemini API 格式
|
||||
GEMINI_CLI = "GEMINI_CLI" # Gemini CLI API 格式
|
||||
|
||||
@@ -188,12 +188,16 @@ class ProviderNotAvailableException(ProviderException):
|
||||
message: str,
|
||||
provider_name: Optional[str] = None,
|
||||
request_metadata: Optional[Any] = None,
|
||||
upstream_status: Optional[int] = None,
|
||||
upstream_response: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
message=message,
|
||||
provider_name=provider_name,
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
self.upstream_status = upstream_status
|
||||
self.upstream_response = upstream_response
|
||||
|
||||
|
||||
class ProviderTimeoutException(ProviderException):
|
||||
@@ -442,6 +446,36 @@ class EmbeddedErrorException(ProviderException):
|
||||
self.error_status = error_status
|
||||
|
||||
|
||||
class ProviderCompatibilityException(ProviderException):
|
||||
"""Provider 兼容性错误异常 - 应该触发故障转移
|
||||
|
||||
用于处理因 Provider 不支持某些参数或功能导致的错误。
|
||||
这类错误不是用户请求本身的问题,换一个 Provider 可能就能成功,应该触发故障转移。
|
||||
|
||||
常见场景:
|
||||
- Unsupported parameter(不支持的参数)
|
||||
- Unsupported model(不支持的模型)
|
||||
- Unsupported feature(不支持的功能)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
provider_name: Optional[str] = None,
|
||||
status_code: int = 400,
|
||||
upstream_error: Optional[str] = None,
|
||||
request_metadata: Optional[Any] = None,
|
||||
):
|
||||
self.upstream_error = upstream_error
|
||||
super().__init__(
|
||||
message=message,
|
||||
provider_name=provider_name,
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
# 覆盖状态码为 400(保持与上游一致)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class UpstreamClientException(ProxyException):
|
||||
"""上游返回的客户端错误异常 - HTTP 4xx 错误,不应该重试
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ model_mapping_resolution_total = Counter(
    "model_mapping_resolution_total",
    "Total number of model mapping resolutions",
    ["method", "cache_hit"],
    # method: direct_match, provider_model_name, alias, not_found
    # method: direct_match, provider_model_name, mapping, not_found
    # cache_hit: true, false
)

src/main.py
@@ -4,13 +4,10 @@
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from src.api.admin import router as admin_router
|
||||
from src.api.announcements import router as announcement_router
|
||||
@@ -299,33 +296,6 @@ app.include_router(dashboard_router) # 仪表盘端点
|
||||
app.include_router(public_router) # 公开API端点(用户可查看提供商和模型)
|
||||
app.include_router(monitoring_router) # 监控端点
|
||||
|
||||
# 静态文件服务(前端构建产物)
|
||||
# 检查前端构建目录是否存在
|
||||
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
|
||||
if frontend_dist.exists():
|
||||
# 挂载静态资源目录
|
||||
app.mount("/assets", StaticFiles(directory=str(frontend_dist / "assets")), name="assets")
|
||||
|
||||
# SPA catch-all路由 - 必须放在最后
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_spa(request: Request, full_path: str):
|
||||
"""
|
||||
处理所有未匹配的GET请求,返回index.html供前端路由处理
|
||||
仅对非API路径生效
|
||||
"""
|
||||
# 如果是API路径,不处理
|
||||
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
# 返回index.html,让前端路由处理
|
||||
index_file = frontend_dist / "index.html"
|
||||
if index_file.exists():
|
||||
return FileResponse(str(index_file))
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Frontend not built")
|
||||
|
||||
else:
|
||||
logger.warning("前端构建目录不存在,前端路由将无法使用")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -336,10 +336,44 @@ class PluginMiddleware:
|
||||
)
|
||||
return result
|
||||
return None
|
||||
except ConnectionError as e:
|
||||
# Redis 连接错误:根据配置决定
|
||||
logger.warning(f"Rate limit connection error: {e}")
|
||||
if config.rate_limit_fail_open:
|
||||
return None
|
||||
else:
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
retry_after=30,
|
||||
message="Rate limit service unavailable"
|
||||
)
|
||||
except TimeoutError as e:
|
||||
# 超时错误:可能是负载过高,根据配置决定
|
||||
logger.warning(f"Rate limit timeout: {e}")
|
||||
if config.rate_limit_fail_open:
|
||||
return None
|
||||
else:
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
retry_after=30,
|
||||
message="Rate limit service timeout"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit error: {e}")
|
||||
# 发生错误时允许请求通过
|
||||
return None
|
||||
logger.error(f"Rate limit error: {type(e).__name__}: {e}")
|
||||
# 其他异常:根据配置决定
|
||||
if config.rate_limit_fail_open:
|
||||
# fail-open: 异常时放行请求(优先可用性)
|
||||
return None
|
||||
else:
|
||||
# fail-close: 异常时拒绝请求(优先安全性)
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
retry_after=60,
|
||||
message="Rate limit service error"
|
||||
)
|
||||
|
||||
async def _call_pre_request_plugins(self, request: Request) -> None:
|
||||
"""调用请求前的插件(当前保留扩展点)"""
|
||||
|
||||
@@ -346,9 +346,9 @@ class ModelCreate(BaseModel):
|
||||
provider_model_name: str = Field(
|
||||
..., min_length=1, max_length=200, description="Provider 侧的主模型名称"
|
||||
)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
provider_model_mappings: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
description="模型名称映射列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
||||
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
@@ -376,9 +376,9 @@ class ModelUpdate(BaseModel):
|
||||
"""更新模型请求"""
|
||||
|
||||
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
provider_model_mappings: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
description="模型名称映射列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: Optional[str] = None
|
||||
# 按次计费配置
|
||||
@@ -404,7 +404,7 @@ class ModelResponse(BaseModel):
|
||||
provider_id: str
|
||||
global_model_id: Optional[str]
|
||||
provider_model_name: str
|
||||
provider_model_aliases: Optional[List[dict]] = None
|
||||
provider_model_mappings: Optional[List[dict]] = None
|
||||
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = None
|
||||
|
||||
@@ -671,10 +671,10 @@ class Model(Base):
|
||||
|
||||
# Provider 映射配置
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
||||
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
|
||||
# 模型名称映射列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
|
||||
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
|
||||
# 为空时只使用 provider_model_name
|
||||
provider_model_aliases = Column(JSON, nullable=True, default=None)
|
||||
provider_model_mappings = Column(JSON, nullable=True, default=None)
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
|
||||
price_per_request = Column(Float, nullable=True) # 每次请求固定费用
|
||||
@@ -820,25 +820,25 @@ class Model(Base):
|
||||
) -> str:
|
||||
"""按优先级选择要使用的 Provider 模型名称
|
||||
|
||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||
相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
|
||||
如果配置了 provider_model_mappings,按优先级选择(数字越小越优先);
|
||||
相同优先级的映射通过哈希分散实现负载均衡(与 Key 调度策略一致);
|
||||
否则返回 provider_model_name。
|
||||
|
||||
Args:
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一映射
|
||||
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的映射
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
if not self.provider_model_aliases:
|
||||
if not self.provider_model_mappings:
|
||||
return self.provider_model_name
|
||||
|
||||
raw_aliases = self.provider_model_aliases
|
||||
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
|
||||
raw_mappings = self.provider_model_mappings
|
||||
if not isinstance(raw_mappings, list) or len(raw_mappings) == 0:
|
||||
return self.provider_model_name
|
||||
|
||||
aliases: list[dict] = []
|
||||
for raw in raw_aliases:
|
||||
mappings: list[dict] = []
|
||||
for raw in raw_mappings:
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
name = raw.get("name")
|
||||
@@ -846,10 +846,10 @@ class Model(Base):
|
||||
continue
|
||||
|
||||
# 检查 api_formats 作用域(如果配置了且当前有 api_format)
|
||||
alias_api_formats = raw.get("api_formats")
|
||||
if api_format and alias_api_formats:
|
||||
mapping_api_formats = raw.get("api_formats")
|
||||
if api_format and mapping_api_formats:
|
||||
# 如果配置了作用域,只有匹配时才生效
|
||||
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
|
||||
if isinstance(mapping_api_formats, list) and api_format not in mapping_api_formats:
|
||||
continue
|
||||
|
||||
raw_priority = raw.get("priority", 1)
|
||||
@@ -860,47 +860,47 @@ class Model(Base):
|
||||
if priority < 1:
|
||||
priority = 1
|
||||
|
||||
aliases.append({"name": name.strip(), "priority": priority})
|
||||
mappings.append({"name": name.strip(), "priority": priority})
|
||||
|
||||
if not aliases:
|
||||
if not mappings:
|
||||
return self.provider_model_name
|
||||
|
||||
# 按优先级排序(数字越小越优先)
|
||||
sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
|
||||
sorted_mappings = sorted(mappings, key=lambda x: x["priority"])
|
||||
|
||||
# 获取最高优先级(最小数字)
|
||||
highest_priority = sorted_aliases[0]["priority"]
|
||||
highest_priority = sorted_mappings[0]["priority"]
|
||||
|
||||
# 获取所有最高优先级的别名
|
||||
top_priority_aliases = [
|
||||
alias for alias in sorted_aliases
|
||||
if alias["priority"] == highest_priority
|
||||
# 获取所有最高优先级的映射
|
||||
top_priority_mappings = [
|
||||
mapping for mapping in sorted_mappings
|
||||
if mapping["priority"] == highest_priority
|
||||
]
|
||||
|
||||
# 如果有多个相同优先级的别名,通过哈希分散选择
|
||||
if len(top_priority_aliases) > 1 and affinity_key:
|
||||
# 为每个别名计算哈希得分,选择得分最小的
|
||||
def hash_score(alias: dict) -> int:
|
||||
combined = f"{affinity_key}:{alias['name']}"
|
||||
# 如果有多个相同优先级的映射,通过哈希分散选择
|
||||
if len(top_priority_mappings) > 1 and affinity_key:
|
||||
# 为每个映射计算哈希得分,选择得分最小的
|
||||
def hash_score(mapping: dict) -> int:
|
||||
combined = f"{affinity_key}:{mapping['name']}"
|
||||
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
|
||||
|
||||
selected = min(top_priority_aliases, key=hash_score)
|
||||
elif len(top_priority_aliases) > 1:
|
||||
selected = min(top_priority_mappings, key=hash_score)
|
||||
elif len(top_priority_mappings) > 1:
|
||||
# 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
|
||||
# 避免随机选择导致同一请求重试时选择不同的模型名称
|
||||
selected = min(top_priority_aliases, key=lambda x: x["name"])
|
||||
selected = min(top_priority_mappings, key=lambda x: x["name"])
|
||||
else:
|
||||
selected = top_priority_aliases[0]
|
||||
selected = top_priority_mappings[0]
|
||||
|
||||
return selected["name"]
|
||||
|
||||
def get_all_provider_model_names(self) -> list[str]:
|
||||
"""获取所有可用的 Provider 模型名称(主名称 + 别名)"""
|
||||
"""获取所有可用的 Provider 模型名称(主名称 + 映射名称)"""
|
||||
names = [self.provider_model_name]
|
||||
if self.provider_model_aliases:
|
||||
for alias in self.provider_model_aliases:
|
||||
if isinstance(alias, dict) and alias.get("name"):
|
||||
names.append(alias["name"])
|
||||
if self.provider_model_mappings:
|
||||
for mapping in self.provider_model_mappings:
|
||||
if isinstance(mapping, dict) and mapping.get("name"):
|
||||
names.append(mapping["name"])
|
||||
return names
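
A standalone sketch of the selection rule documented above, with hypothetical mapping entries and an invented affinity key; it mirrors the priority-then-hash logic shown in the hunk rather than calling the model class.

import hashlib

# Hypothetical provider_model_mappings content
mappings = [
    {"name": "Claude-Sonnet-4.5", "priority": 1},
    {"name": "Claude-Sonnet-4-5", "priority": 1},
    {"name": "claude-sonnet-legacy", "priority": 2},
]
affinity_key = "user-api-key-hash"  # invented example

top = min(m["priority"] for m in mappings)
candidates = [m for m in mappings if m["priority"] == top]

def hash_score(mapping: dict) -> int:
    combined = f"{affinity_key}:{mapping['name']}"
    return int(hashlib.md5(combined.encode()).hexdigest(), 16)

if len(candidates) > 1 and affinity_key:
    selected = min(candidates, key=hash_score)           # stable per affinity key, spreads load
elif len(candidates) > 1:
    selected = min(candidates, key=lambda m: m["name"])  # deterministic when no key is given
else:
    selected = candidates[0]
print(selected["name"])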
|
||||
|
||||
|
||||
@@ -1308,6 +1308,53 @@ class StatsDaily(Base):
|
||||
)
|
||||
|
||||
|
||||
class StatsDailyModel(Base):
|
||||
"""每日模型统计快照 - 用于快速查询每日模型维度数据"""
|
||||
|
||||
__tablename__ = "stats_daily_model"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
|
||||
# 统计日期 (UTC)
|
||||
date = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# 模型名称
|
||||
model = Column(String(100), nullable=False)
|
||||
|
||||
# 请求统计
|
||||
total_requests = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Token 统计
|
||||
input_tokens = Column(BigInteger, default=0, nullable=False)
|
||||
output_tokens = Column(BigInteger, default=0, nullable=False)
|
||||
cache_creation_tokens = Column(BigInteger, default=0, nullable=False)
|
||||
cache_read_tokens = Column(BigInteger, default=0, nullable=False)
|
||||
|
||||
# 成本统计 (USD)
|
||||
total_cost = Column(Float, default=0.0, nullable=False)
|
||||
|
||||
# 性能统计
|
||||
avg_response_time_ms = Column(Float, default=0.0, nullable=False)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 唯一约束:每个模型每天只有一条记录
|
||||
__table_args__ = (
|
||||
UniqueConstraint("date", "model", name="uq_stats_daily_model"),
|
||||
Index("idx_stats_daily_model_date", "date"),
|
||||
Index("idx_stats_daily_model_date_model", "date", "model"),
|
||||
)
|
||||
|
||||
|
||||
class StatsSummary(Base):
|
||||
"""全局统计汇总 - 单行记录,存储截止到昨天的累计数据"""
|
||||
|
||||
|
||||
@@ -226,8 +226,11 @@ class EndpointAPIKeyUpdate(BaseModel):
|
||||
global_priority: Optional[int] = Field(
|
||||
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
||||
)
|
||||
# 注意:max_concurrent=None 表示不更新,要切换为自适应模式请使用专用 API
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
||||
# max_concurrent: 使用特殊标记区分"未提供"和"设置为 null(自适应模式)"
|
||||
# - 不提供字段:不更新
|
||||
# - 提供 null:切换为自适应模式
|
||||
# - 提供数字:设置固定并发限制
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数(null=自适应模式)")
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
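
To make the tri-state convention for max_concurrent above concrete, three hypothetical update payloads; field names follow the schema, the values are examples, and how the explicit null is detected depends on the sentinel mechanism mentioned in the comment.

leave_unchanged = {"rate_limit": 100}        # max_concurrent omitted -> concurrency setting untouched
switch_adaptive = {"max_concurrent": None}   # explicit null -> adaptive mode
fix_at_eight    = {"max_concurrent": 8}      # positive integer -> fixed limit of 8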
|
||||
|
||||
@@ -301,6 +301,36 @@ class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class ImportFromUpstreamRequest(BaseModel):
|
||||
"""从上游提供商导入模型请求"""
|
||||
|
||||
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
|
||||
|
||||
|
||||
class ImportFromUpstreamSuccessItem(BaseModel):
|
||||
"""导入成功的模型信息"""
|
||||
|
||||
model_id: str = Field(..., description="上游模型 ID")
|
||||
global_model_id: str = Field(..., description="GlobalModel ID")
|
||||
global_model_name: str = Field(..., description="GlobalModel 名称")
|
||||
provider_model_id: str = Field(..., description="Provider Model ID")
|
||||
created_global_model: bool = Field(..., description="是否新创建了 GlobalModel")
|
||||
|
||||
|
||||
class ImportFromUpstreamErrorItem(BaseModel):
|
||||
"""导入失败的模型信息"""
|
||||
|
||||
model_id: str = Field(..., description="上游模型 ID")
|
||||
error: str = Field(..., description="错误信息")
|
||||
|
||||
|
||||
class ImportFromUpstreamResponse(BaseModel):
|
||||
"""从上游提供商导入模型响应"""
|
||||
|
||||
success: List[ImportFromUpstreamSuccessItem]
|
||||
errors: List[ImportFromUpstreamErrorItem]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
@@ -311,6 +341,10 @@ __all__ = [
|
||||
"GlobalModelResponse",
|
||||
"GlobalModelUpdate",
|
||||
"GlobalModelWithStats",
|
||||
"ImportFromUpstreamErrorItem",
|
||||
"ImportFromUpstreamRequest",
|
||||
"ImportFromUpstreamResponse",
|
||||
"ImportFromUpstreamSuccessItem",
|
||||
"ModelCapabilities",
|
||||
"ModelCatalogItem",
|
||||
"ModelCatalogProviderDetail",
|
||||
|
||||
@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
|
||||
if config.environment == "production":
|
||||
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
||||
config.jwt_secret_key = secrets.token_urlsafe(32)
|
||||
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
|
||||
logger.warning("JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发")
|
||||
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
||||
|
||||
JWT_SECRET_KEY = config.jwt_secret_key
|
||||
|
||||
src/services/cache/aware_scheduler.py
@@ -589,14 +589,14 @@ class CacheAwareScheduler:
|
||||
|
||||
target_format = normalize_api_format(api_format)
|
||||
|
||||
# 0. 解析 model_name 到 GlobalModel(支持直接匹配和别名匹配,使用 ModelCacheService)
|
||||
# 0. 解析 model_name 到 GlobalModel(支持直接匹配和映射名匹配,使用 ModelCacheService)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
logger.warning(f"GlobalModel not found: {model_name}")
|
||||
raise ModelNotSupportedException(model=model_name)
|
||||
|
||||
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保别名和规范名都能命中同一个缓存
|
||||
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保映射名和规范名都能命中同一个缓存
|
||||
global_model_id: str = str(global_model.id)
|
||||
requested_model_name = model_name
|
||||
resolved_model_name = str(global_model.name)
|
||||
@@ -751,19 +751,19 @@ class CacheAwareScheduler:
|
||||
|
||||
支持两种匹配方式:
|
||||
1. 直接匹配 GlobalModel.name
|
||||
2. 通过 ModelCacheService 匹配别名(全局查找)
|
||||
2. 通过 ModelCacheService 匹配映射名(全局查找)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: Provider 对象
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或别名)
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或映射名)
|
||||
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
|
||||
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
|
||||
|
||||
Returns:
|
||||
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
|
||||
"""
|
||||
# 使用 ModelCacheService 解析模型名称(支持别名)
|
||||
# 使用 ModelCacheService 解析模型名称(支持映射名)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
@@ -914,7 +914,7 @@ class CacheAwareScheduler:
|
||||
db: 数据库会话
|
||||
providers: Provider 列表
|
||||
target_format: 目标 API 格式
|
||||
model_name: 模型名称(用户请求的名称,可能是别名)
|
||||
model_name: 模型名称(用户请求的名称,可能是映射名)
|
||||
affinity_key: 亲和性标识符(通常为API Key ID)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(用于 Key.allowed_models 校验)
|
||||
max_candidates: 最大候选数
|
||||
|
||||
src/services/cache/model_cache.py
@@ -198,7 +198,7 @@ class ModelCacheService:
|
||||
provider_id: Optional[str] = None,
|
||||
global_model_id: Optional[str] = None,
|
||||
provider_model_name: Optional[str] = None,
|
||||
provider_model_aliases: Optional[list] = None,
|
||||
provider_model_mappings: Optional[list] = None,
|
||||
) -> None:
|
||||
"""清除 Model 缓存
|
||||
|
||||
@@ -207,7 +207,7 @@ class ModelCacheService:
|
||||
provider_id: Provider ID(用于清除 provider_global 缓存)
|
||||
global_model_id: GlobalModel ID(用于清除 provider_global 缓存)
|
||||
provider_model_name: provider_model_name(用于清除 resolve 缓存)
|
||||
provider_model_aliases: 映射名称列表(用于清除 resolve 缓存)
|
||||
provider_model_mappings: 映射名称列表(用于清除 resolve 缓存)
|
||||
"""
|
||||
# 清除 model:id 缓存
|
||||
await CacheService.delete(f"model:id:{model_id}")
|
||||
@@ -222,16 +222,16 @@ class ModelCacheService:
|
||||
else:
|
||||
logger.debug(f"Model 缓存已清除: {model_id}")
|
||||
|
||||
# 清除 resolve 缓存(provider_model_name 和 aliases 可能都被用作解析 key)
|
||||
# 清除 resolve 缓存(provider_model_name 和 mappings 可能都被用作解析 key)
|
||||
resolve_keys_to_clear = []
|
||||
if provider_model_name:
|
||||
resolve_keys_to_clear.append(provider_model_name)
|
||||
if provider_model_aliases:
|
||||
for alias_entry in provider_model_aliases:
|
||||
if isinstance(alias_entry, dict):
|
||||
alias_name = alias_entry.get("name", "").strip()
|
||||
if alias_name:
|
||||
resolve_keys_to_clear.append(alias_name)
|
||||
if provider_model_mappings:
|
||||
for mapping_entry in provider_model_mappings:
|
||||
if isinstance(mapping_entry, dict):
|
||||
mapping_name = mapping_entry.get("name", "").strip()
|
||||
if mapping_name:
|
||||
resolve_keys_to_clear.append(mapping_name)
|
||||
|
||||
for key in resolve_keys_to_clear:
|
||||
await CacheService.delete(f"global_model:resolve:{key}")
|
||||
@@ -261,8 +261,8 @@ class ModelCacheService:
|
||||
2. 通过 provider_model_name 匹配(查询 Model 表)
|
||||
3. 直接匹配 GlobalModel.name(兜底)
|
||||
|
||||
注意:此方法不使用 provider_model_aliases 进行全局解析。
|
||||
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
|
||||
注意:此方法不使用 provider_model_mappings 进行全局解析。
|
||||
provider_model_mappings 是 Provider 级别的映射配置,只在特定 Provider 上下文中生效,
|
||||
由 resolve_provider_model() 处理。
|
||||
|
||||
Args:
|
||||
@@ -301,9 +301,9 @@ class ModelCacheService:
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
||||
return ModelCacheService._dict_to_global_model(cached_data)
|
||||
|
||||
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases)
|
||||
# 重要:provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
|
||||
# 全局解析不应该受到某个 Provider 别名配置的影响
|
||||
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_mappings)
|
||||
# 重要:provider_model_mappings 是 Provider 级别的映射配置,只在特定 Provider 上下文中生效
|
||||
# 全局解析不应该受到某个 Provider 映射配置的影响
|
||||
# 例如:Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
|
||||
from src.models.database import Provider
|
||||
|
||||
@@ -401,7 +401,7 @@ class ModelCacheService:
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
|
||||
"provider_model_mappings": getattr(model, "provider_model_mappings", None),
|
||||
"is_active": model.is_active,
|
||||
"is_available": model.is_available if hasattr(model, "is_available") else True,
|
||||
"price_per_request": (
|
||||
@@ -424,7 +424,7 @@ class ModelCacheService:
|
||||
provider_id=model_dict["provider_id"],
|
||||
global_model_id=model_dict["global_model_id"],
|
||||
provider_model_name=model_dict["provider_model_name"],
|
||||
provider_model_aliases=model_dict.get("provider_model_aliases"),
|
||||
provider_model_mappings=model_dict.get("provider_model_mappings"),
|
||||
is_active=model_dict["is_active"],
|
||||
is_available=model_dict.get("is_available", True),
|
||||
price_per_request=model_dict.get("price_per_request"),
|
||||
|
||||
@@ -443,7 +443,7 @@ class ModelCostService:
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或映射名)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
|
||||
@@ -84,11 +84,11 @@ class ModelMapperMiddleware:
|
||||
获取模型映射
|
||||
|
||||
简化后的逻辑:
|
||||
1. 通过 GlobalModel.name 或别名解析 GlobalModel
|
||||
1. 通过 GlobalModel.name 或映射名解析 GlobalModel
|
||||
2. 找到 GlobalModel 后,查找该 Provider 的 Model 实现
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名(可以是 GlobalModel.name 或别名)
|
||||
source_model: 用户请求的模型名(可以是 GlobalModel.name 或映射名)
|
||||
provider_id: 提供商ID (UUID)
|
||||
|
||||
Returns:
|
||||
@@ -101,7 +101,7 @@ class ModelMapperMiddleware:
|
||||
|
||||
mapping = None
|
||||
|
||||
# 步骤 1: 解析 GlobalModel(支持别名)
|
||||
# 步骤 1: 解析 GlobalModel(支持映射名)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(
|
||||
self.db, source_model
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
|
||||
from src.models.database import Model, Provider
|
||||
from src.api.base.models_service import invalidate_models_list_cache
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
|
||||
@@ -50,7 +51,7 @@ class ModelService:
|
||||
provider_id=provider_id,
|
||||
global_model_id=model_data.global_model_id,
|
||||
provider_model_name=model_data.provider_model_name,
|
||||
provider_model_aliases=model_data.provider_model_aliases,
|
||||
provider_model_mappings=model_data.provider_model_mappings,
|
||||
price_per_request=model_data.price_per_request,
|
||||
tiered_pricing=model_data.tiered_pricing,
|
||||
supports_vision=model_data.supports_vision,
|
||||
@@ -75,6 +76,10 @@ class ModelService:
|
||||
)
|
||||
|
||||
logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
asyncio.create_task(invalidate_models_list_cache())
|
||||
|
||||
return model
|
||||
|
||||
except IntegrityError as e:
|
||||
@@ -148,9 +153,9 @@ class ModelService:
|
||||
if not model:
|
||||
raise NotFoundException(f"模型 {model_id} 不存在")
|
||||
|
||||
# 保存旧的别名,用于清除缓存
|
||||
# 保存旧的映射,用于清除缓存
|
||||
old_provider_model_name = model.provider_model_name
|
||||
old_provider_model_aliases = model.provider_model_aliases
|
||||
old_provider_model_mappings = model.provider_model_mappings
|
||||
|
||||
# 更新字段
|
||||
update_data = model_data.model_dump(exclude_unset=True)
|
||||
@@ -169,26 +174,26 @@ class ModelService:
|
||||
db.refresh(model)
|
||||
|
||||
# 清除 Redis 缓存(异步执行,不阻塞返回)
|
||||
# 先清除旧的别名缓存
|
||||
# 先清除旧的映射缓存
|
||||
asyncio.create_task(
|
||||
ModelCacheService.invalidate_model_cache(
|
||||
model_id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=old_provider_model_name,
|
||||
provider_model_aliases=old_provider_model_aliases,
|
||||
provider_model_mappings=old_provider_model_mappings,
|
||||
)
|
||||
)
|
||||
# 再清除新的别名缓存(如果有变化)
|
||||
# 再清除新的映射缓存(如果有变化)
|
||||
if (model.provider_model_name != old_provider_model_name or
|
||||
model.provider_model_aliases != old_provider_model_aliases):
|
||||
model.provider_model_mappings != old_provider_model_mappings):
|
||||
asyncio.create_task(
|
||||
ModelCacheService.invalidate_model_cache(
|
||||
model_id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
provider_model_mappings=model.provider_model_mappings,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -197,6 +202,9 @@ class ModelService:
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
asyncio.create_task(invalidate_models_list_cache())
|
||||
|
||||
logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
|
||||
return model
|
||||
except IntegrityError as e:
|
||||
@@ -238,7 +246,7 @@ class ModelService:
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": model.provider_model_aliases,
|
||||
"provider_model_mappings": model.provider_model_mappings,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -252,7 +260,7 @@ class ModelService:
|
||||
provider_id=cache_info["provider_id"],
|
||||
global_model_id=cache_info["global_model_id"],
|
||||
provider_model_name=cache_info["provider_model_name"],
|
||||
provider_model_aliases=cache_info["provider_model_aliases"],
|
||||
provider_model_mappings=cache_info["provider_model_mappings"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -261,6 +269,9 @@ class ModelService:
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
asyncio.create_task(invalidate_models_list_cache())
|
||||
|
||||
logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
|
||||
f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
|
||||
except Exception as e:
|
||||
@@ -286,7 +297,7 @@ class ModelService:
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
provider_model_mappings=model.provider_model_mappings,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -295,6 +306,9 @@ class ModelService:
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
asyncio.create_task(invalidate_models_list_cache())
|
||||
|
||||
status = "可用" if is_available else "不可用"
|
||||
logger.info(f"更新模型可用状态: id={model_id}, status={status}")
|
||||
return model
|
||||
@@ -358,6 +372,9 @@ class ModelService:
|
||||
for model in created_models:
|
||||
db.refresh(model)
|
||||
logger.info(f"批量创建 {len(created_models)} 个模型成功")
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
asyncio.create_task(invalidate_models_list_cache())
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
logger.error(f"批量创建模型失败: {str(e)}")
|
||||
@@ -373,7 +390,7 @@ class ModelService:
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
provider_model_mappings=model.provider_model_mappings,
|
||||
# 原始配置值(可能为空)
|
||||
price_per_request=model.price_per_request,
|
||||
tiered_pricing=model.tiered_pricing,
|
||||
|
||||
@@ -15,6 +15,7 @@ from src.core.enums import APIFormat
|
||||
from src.core.exceptions import (
|
||||
ConcurrencyLimitError,
|
||||
ProviderAuthException,
|
||||
ProviderCompatibilityException,
|
||||
ProviderException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
@@ -81,7 +82,9 @@ class ErrorClassifier:
|
||||
"context_length_exceeded", # 上下文长度超限
|
||||
"content_length_limit", # 请求内容长度超限 (Claude API)
|
||||
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
|
||||
"max_tokens", # token 数超限
|
||||
# 注意:移除了 "max_tokens",因为 max_tokens 相关错误可能是 Provider 兼容性问题
|
||||
# 如 "Unsupported parameter: 'max_tokens' is not supported with this model"
|
||||
# 这类错误应由 COMPATIBILITY_ERROR_PATTERNS 处理
|
||||
"invalid_prompt", # 无效的提示词
|
||||
"content too long", # 内容过长
|
||||
"input is too long", # 输入过长 (AWS)
|
||||
@@ -136,6 +139,19 @@ class ErrorClassifier:
|
||||
"CONTENT_POLICY_VIOLATION",
|
||||
)
|
||||
|
||||
# Provider 兼容性错误模式 - 这类错误应该触发故障转移
|
||||
# 因为换一个 Provider 可能就能成功
|
||||
COMPATIBILITY_ERROR_PATTERNS: Tuple[str, ...] = (
|
||||
"unsupported parameter", # 不支持的参数
|
||||
"unsupported model", # 不支持的模型
|
||||
"unsupported feature", # 不支持的功能
|
||||
"not supported with this model", # 此模型不支持
|
||||
"model does not support", # 模型不支持
|
||||
"parameter is not supported", # 参数不支持
|
||||
"feature is not supported", # 功能不支持
|
||||
"not available for this model", # 此模型不可用
|
||||
)
|
||||
|
||||
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析错误响应为结构化数据
|
||||
@@ -261,6 +277,25 @@ class ErrorClassifier:
|
||||
search_text = f"{parsed['message']} {parsed['raw']}".lower()
|
||||
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||
|
||||
def _is_compatibility_error(self, error_text: Optional[str]) -> bool:
|
||||
"""
|
||||
检测错误响应是否为 Provider 兼容性错误(应触发故障转移)
|
||||
|
||||
这类错误是因为 Provider 不支持某些参数或功能导致的,
|
||||
换一个 Provider 可能就能成功。
|
||||
|
||||
Args:
|
||||
error_text: 错误响应文本
|
||||
|
||||
Returns:
|
||||
是否为兼容性错误
|
||||
"""
|
||||
if not error_text:
|
||||
return False
|
||||
|
||||
search_text = error_text.lower()
|
||||
return any(pattern.lower() in search_text for pattern in self.COMPATIBILITY_ERROR_PATTERNS)
|
||||
|
||||
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
从错误响应中提取错误消息
|
||||
@@ -425,6 +460,16 @@ class ErrorClassifier:
|
||||
),
|
||||
)
|
||||
|
||||
# 400 错误:先检查是否为 Provider 兼容性错误(应触发故障转移)
|
||||
if status == 400 and self._is_compatibility_error(error_response_text):
|
||||
logger.info(f"检测到 Provider 兼容性错误,将触发故障转移: {extracted_message}")
|
||||
return ProviderCompatibilityException(
|
||||
message=extracted_message or "Provider 不支持此请求",
|
||||
provider_name=provider_name,
|
||||
status_code=400,
|
||||
upstream_error=error_response_text,
|
||||
)
|
||||
|
||||
# 400 错误:检查是否为客户端请求错误(不应重试)
|
||||
if status == 400 and self._is_client_error(error_response_text):
|
||||
logger.info(f"检测到客户端请求错误,不进行重试: {extracted_message}")
|
||||
|
||||
@@ -427,6 +427,9 @@ class FallbackOrchestrator:
|
||||
)
|
||||
# str(cause) 可能为空(如 httpx 超时异常),使用 repr() 作为备用
|
||||
error_msg = str(cause) or repr(cause)
|
||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=self.db,
|
||||
candidate_id=candidate_record_id,
|
||||
@@ -439,6 +442,9 @@ class FallbackOrchestrator:
|
||||
|
||||
# 未知错误:记录失败并抛出
|
||||
error_msg = str(cause) or repr(cause)
|
||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=self.db,
|
||||
candidate_id=candidate_record_id,
|
||||
|
||||
@@ -289,11 +289,17 @@ class RequestResult:
|
||||
status_code = 500
|
||||
error_type = "internal_error"
|
||||
|
||||
# 构建错误消息,包含上游响应信息
|
||||
error_message = str(exception)
|
||||
if isinstance(exception, ProviderNotAvailableException):
|
||||
if exception.upstream_response:
|
||||
error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
|
||||
|
||||
return cls(
|
||||
status=RequestStatus.FAILED,
|
||||
metadata=metadata,
|
||||
status_code=status_code,
|
||||
error_message=str(exception),
|
||||
error_message=error_message,
|
||||
error_type=error_type,
|
||||
response_time_ms=response_time_ms,
|
||||
is_stream=is_stream,
|
||||
|
||||
@@ -259,6 +259,9 @@ class CleanupScheduler:
|
||||
StatsAggregatorService.aggregate_daily_stats(
|
||||
db, current_date_local
|
||||
)
|
||||
StatsAggregatorService.aggregate_daily_model_stats(
|
||||
db, current_date_local
|
||||
)
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
@@ -291,6 +294,7 @@ class CleanupScheduler:
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||
StatsAggregatorService.aggregate_daily_model_stats(db, yesterday_local)
|
||||
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
|
||||
@@ -12,7 +12,6 @@ from src.core.logger import logger
|
||||
from src.models.database import Provider, SystemConfig
|
||||
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
"""日志记录级别"""
|
||||
|
||||
@@ -94,6 +93,35 @@ class SystemConfigService:
|
||||
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
def get_configs(cls, db: Session, keys: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
批量获取系统配置值
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
keys: 配置键列表
|
||||
|
||||
Returns:
|
||||
配置键值字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 一次查询获取所有配置
|
||||
configs = db.query(SystemConfig).filter(SystemConfig.key.in_(keys)).all()
|
||||
config_map = {c.key: c.value for c in configs}
|
||||
|
||||
# 填充结果,不存在的使用默认值
|
||||
for key in keys:
|
||||
if key in config_map:
|
||||
result[key] = config_map[key]
|
||||
elif key in cls.DEFAULT_CONFIGS:
|
||||
result[key] = cls.DEFAULT_CONFIGS[key]["value"]
|
||||
else:
|
||||
result[key] = None
|
||||
|
||||
return result
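
A usage sketch for the batch getter above; the key names are hypothetical and only meant to show the three fallback paths (stored row, DEFAULT_CONFIGS entry, unknown key).

from sqlalchemy.orm import Session

from src.services.system.config import SystemConfigService


def read_runtime_settings(db: Session) -> dict:
    # one query for all keys; unknown keys fall back to DEFAULT_CONFIGS or None
    return SystemConfigService.get_configs(db, ["log_level", "retention_days", "no_such_key"])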
|
||||
|
||||
@staticmethod
|
||||
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
|
||||
"""设置系统配置值"""
|
||||
@@ -111,6 +139,7 @@ class SystemConfigService:
|
||||
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
@@ -153,8 +182,8 @@ class SystemConfigService:
|
||||
for config in configs
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def delete_config(db: Session, key: str) -> bool:
|
||||
@classmethod
|
||||
def delete_config(cls, db: Session, key: str) -> bool:
|
||||
"""删除系统配置"""
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if config:
|
||||
|
||||
@@ -16,6 +16,7 @@ from src.models.database import (
|
||||
ApiKey,
|
||||
RequestCandidate,
|
||||
StatsDaily,
|
||||
StatsDailyModel,
|
||||
StatsSummary,
|
||||
StatsUserDaily,
|
||||
Usage,
|
||||
@@ -219,6 +220,120 @@ class StatsAggregatorService:
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_model_stats(db: Session, date: datetime) -> list[StatsDailyModel]:
|
||||
"""聚合指定日期的模型维度统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的业务日期
|
||||
|
||||
Returns:
|
||||
StatsDailyModel 记录列表
|
||||
"""
|
||||
day_start, day_end = _get_business_day_range(date)
|
||||
|
||||
# 按模型分组统计
|
||||
model_stats = (
|
||||
db.query(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("total_requests"),
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
func.sum(Usage.output_tokens).label("output_tokens"),
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
.group_by(Usage.model)
|
||||
.all()
|
||||
)
|
||||
|
||||
results = []
|
||||
for stat in model_stats:
|
||||
if not stat.model:
|
||||
continue
|
||||
|
||||
existing = (
|
||||
db.query(StatsDailyModel)
|
||||
.filter(and_(StatsDailyModel.date == day_start, StatsDailyModel.model == stat.model))
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
record = existing
|
||||
else:
|
||||
record = StatsDailyModel(
|
||||
id=str(uuid.uuid4()), date=day_start, model=stat.model
|
||||
)
|
||||
|
||||
record.total_requests = stat.total_requests or 0
|
||||
record.input_tokens = int(stat.input_tokens or 0)
|
||||
record.output_tokens = int(stat.output_tokens or 0)
|
||||
record.cache_creation_tokens = int(stat.cache_creation_tokens or 0)
|
||||
record.cache_read_tokens = int(stat.cache_read_tokens or 0)
|
||||
record.total_cost = float(stat.total_cost or 0)
|
||||
record.avg_response_time_ms = float(stat.avg_response_time or 0)
|
||||
|
||||
if not existing:
|
||||
db.add(record)
|
||||
results.append(record)
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
f"[StatsAggregator] 聚合日期 {date.date()} 模型统计完成: {len(results)} 个模型"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_daily_model_stats(db: Session, start_date: datetime, end_date: datetime) -> list[dict]:
|
||||
"""获取日期范围内的模型统计数据(优先使用预聚合)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
start_date: 开始日期 (UTC)
|
||||
end_date: 结束日期 (UTC)
|
||||
|
||||
Returns:
|
||||
模型统计数据列表
|
||||
"""
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||
|
||||
# 从预聚合表获取历史数据
|
||||
stats = (
|
||||
db.query(StatsDailyModel)
|
||||
.filter(and_(StatsDailyModel.date >= start_date, StatsDailyModel.date < end_date))
|
||||
.order_by(StatsDailyModel.date.asc(), StatsDailyModel.total_cost.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# 转换为字典格式,按日期分组
|
||||
result = []
|
||||
for stat in stats:
|
||||
# 转换日期为业务时区
|
||||
if stat.date.tzinfo is None:
|
||||
date_utc = stat.date.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
date_utc = stat.date.astimezone(timezone.utc)
|
||||
date_str = date_utc.astimezone(app_tz).date().isoformat()
|
||||
|
||||
result.append({
|
||||
"date": date_str,
|
||||
"model": stat.model,
|
||||
"requests": stat.total_requests,
|
||||
"tokens": (
|
||||
stat.input_tokens + stat.output_tokens +
|
||||
stat.cache_creation_tokens + stat.cache_read_tokens
|
||||
),
|
||||
"cost": stat.total_cost,
|
||||
"avg_response_time": stat.avg_response_time_ms / 1000.0 if stat.avg_response_time_ms else 0,
|
||||
})
|
||||
|
||||
return result
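
For reference, the shape of one element returned by get_daily_model_stats, with clearly made-up numbers; the keys follow the dict built in the loop above.

example_row = {
    "date": "2025-01-01",          # business-timezone date string
    "model": "claude-sonnet-4-5",  # hypothetical model name
    "requests": 1200,
    "tokens": 3450000,             # input + output + cache creation + cache read
    "cost": 12.34,                 # USD
    "avg_response_time": 1.8,      # seconds (avg_response_time_ms / 1000)
}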
|
||||
|
||||
@staticmethod
|
||||
def aggregate_user_daily_stats(
|
||||
db: Session, user_id: str, date: datetime
|
||||
@@ -497,6 +612,7 @@ class StatsAggregatorService:
|
||||
current_date = start_date
|
||||
while current_date < today_local:
|
||||
StatsAggregatorService.aggregate_daily_stats(db, current_date)
|
||||
StatsAggregatorService.aggregate_daily_model_stats(db, current_date)
|
||||
count += 1
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageRecordParams:
|
||||
"""用量记录参数数据类,用于在内部方法间传递数据"""
|
||||
db: Session
|
||||
user: Optional[User]
|
||||
api_key: Optional[ApiKey]
|
||||
provider: str
|
||||
model: str
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: int
|
||||
cache_read_input_tokens: int
|
||||
request_type: str
|
||||
api_format: Optional[str]
|
||||
is_stream: bool
|
||||
response_time_ms: Optional[int]
|
||||
first_byte_time_ms: Optional[int]
|
||||
status_code: int
|
||||
error_message: Optional[str]
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
request_headers: Optional[Dict[str, Any]]
|
||||
request_body: Optional[Any]
|
||||
provider_request_headers: Optional[Dict[str, Any]]
|
||||
response_headers: Optional[Dict[str, Any]]
|
||||
response_body: Optional[Any]
|
||||
request_id: str
|
||||
provider_id: Optional[str]
|
||||
provider_endpoint_id: Optional[str]
|
||||
provider_api_key_id: Optional[str]
|
||||
status: str
|
||||
cache_ttl_minutes: Optional[int]
|
||||
use_tiered_pricing: bool
|
||||
target_model: Optional[str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""验证关键字段,确保数据完整性"""
|
||||
# Token 数量不能为负数
|
||||
if self.input_tokens < 0:
|
||||
raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}")
|
||||
if self.output_tokens < 0:
|
||||
raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}")
|
||||
if self.cache_creation_input_tokens < 0:
|
||||
raise ValueError(
|
||||
f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}"
|
||||
)
|
||||
if self.cache_read_input_tokens < 0:
|
||||
raise ValueError(
|
||||
f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}"
|
||||
)
|
||||
|
||||
# 响应时间不能为负数
|
||||
if self.response_time_ms is not None and self.response_time_ms < 0:
|
||||
raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}")
|
||||
if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0:
|
||||
raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}")
|
||||
|
||||
# HTTP 状态码范围校验
|
||||
if not (100 <= self.status_code <= 599):
|
||||
raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
|
||||
|
||||
# 状态值校验
|
||||
valid_statuses = {"pending", "streaming", "completed", "failed"}
|
||||
if self.status not in valid_statuses:
|
||||
raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
|
||||
|
||||
|
class UsageService:
    """用量统计服务"""

@@ -471,6 +537,97 @@ class UsageService:
            cache_ttl_minutes=cache_ttl_minutes,
        )

    @classmethod
    async def _prepare_usage_record(
        cls,
        params: UsageRecordParams,
    ) -> Tuple[Dict[str, Any], float]:
        """准备用量记录的共享逻辑

        此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑:
        - 获取费率倍数
        - 计算成本
        - 构建 Usage 参数

        Args:
            params: 用量记录参数数据类

        Returns:
            (usage_params 字典, total_cost 总成本)
        """
        # 获取费率倍数和是否免费套餐
        actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
            params.db, params.provider_api_key_id, params.provider_id
        )

        # 计算成本
        is_failed_request = params.status_code >= 400 or params.error_message is not None
        (
            input_price, output_price, cache_creation_price, cache_read_price, request_price,
            input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
            request_cost, total_cost, _tier_index
        ) = await cls._calculate_costs(
            db=params.db,
            provider=params.provider,
            model=params.model,
            input_tokens=params.input_tokens,
            output_tokens=params.output_tokens,
            cache_creation_input_tokens=params.cache_creation_input_tokens,
            cache_read_input_tokens=params.cache_read_input_tokens,
            api_format=params.api_format,
            cache_ttl_minutes=params.cache_ttl_minutes,
            use_tiered_pricing=params.use_tiered_pricing,
            is_failed_request=is_failed_request,
        )

        # 构建 Usage 参数
        usage_params = cls._build_usage_params(
            db=params.db,
            user=params.user,
            api_key=params.api_key,
            provider=params.provider,
            model=params.model,
            input_tokens=params.input_tokens,
            output_tokens=params.output_tokens,
            cache_creation_input_tokens=params.cache_creation_input_tokens,
            cache_read_input_tokens=params.cache_read_input_tokens,
            request_type=params.request_type,
            api_format=params.api_format,
            is_stream=params.is_stream,
            response_time_ms=params.response_time_ms,
            first_byte_time_ms=params.first_byte_time_ms,
            status_code=params.status_code,
            error_message=params.error_message,
            metadata=params.metadata,
            request_headers=params.request_headers,
            request_body=params.request_body,
            provider_request_headers=params.provider_request_headers,
            response_headers=params.response_headers,
            response_body=params.response_body,
            request_id=params.request_id,
            provider_id=params.provider_id,
            provider_endpoint_id=params.provider_endpoint_id,
            provider_api_key_id=params.provider_api_key_id,
            status=params.status,
            target_model=params.target_model,
            input_cost=input_cost,
            output_cost=output_cost,
            cache_creation_cost=cache_creation_cost,
            cache_read_cost=cache_read_cost,
            cache_cost=cache_cost,
            request_cost=request_cost,
            total_cost=total_cost,
            input_price=input_price,
            output_price=output_price,
            cache_creation_price=cache_creation_price,
            cache_read_price=cache_read_price,
            request_price=request_price,
            actual_rate_multiplier=actual_rate_multiplier,
            is_free_tier=is_free_tier,
        )

        return usage_params, total_cost

    @classmethod
    async def record_usage_async(
        cls,
@@ -516,76 +673,25 @@ class UsageService:
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        # 获取费率倍数和是否免费套餐
        actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id, provider_id
        )

        # 计算成本
        is_failed_request = status_code >= 400 or error_message is not None
        (
            input_price, output_price, cache_creation_price, cache_read_price, request_price,
            input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
            request_cost, total_cost, tier_index
        ) = await cls._calculate_costs(
            db=db,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        # 使用共享逻辑准备记录参数
        params = UsageRecordParams(
            db=db, user=user, api_key=api_key, provider=provider, model=model,
            input_tokens=input_tokens, output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            api_format=api_format,
            cache_ttl_minutes=cache_ttl_minutes,
            use_tiered_pricing=use_tiered_pricing,
            is_failed_request=is_failed_request,
        )

        # 构建 Usage 参数
        usage_params = cls._build_usage_params(
            db=db,
            user=user,
            api_key=api_key,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            request_type=request_type,
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            first_byte_time_ms=first_byte_time_ms,
            status_code=status_code,
            error_message=error_message,
            metadata=metadata,
            request_headers=request_headers,
            request_body=request_body,
            request_type=request_type, api_format=api_format, is_stream=is_stream,
            response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
            status_code=status_code, error_message=error_message, metadata=metadata,
            request_headers=request_headers, request_body=request_body,
            provider_request_headers=provider_request_headers,
            response_headers=response_headers,
            response_body=response_body,
            request_id=request_id,
            provider_id=provider_id,
            response_headers=response_headers, response_body=response_body,
            request_id=request_id, provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
            status=status,
            provider_api_key_id=provider_api_key_id, status=status,
            cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
            target_model=target_model,
            input_cost=input_cost,
            output_cost=output_cost,
            cache_creation_cost=cache_creation_cost,
            cache_read_cost=cache_read_cost,
            cache_cost=cache_cost,
            request_cost=request_cost,
            total_cost=total_cost,
            input_price=input_price,
            output_price=output_price,
            cache_creation_price=cache_creation_price,
            cache_read_price=cache_read_price,
            request_price=request_price,
            actual_rate_multiplier=actual_rate_multiplier,
            is_free_tier=is_free_tier,
        )
        usage_params, _ = await cls._prepare_usage_record(params)

        # 创建 Usage 记录
        usage = Usage(**usage_params)
@@ -660,76 +766,25 @@ class UsageService:
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        # 获取费率倍数和是否免费套餐
        actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id, provider_id
        )

        # 计算成本
        is_failed_request = status_code >= 400 or error_message is not None
        (
            input_price, output_price, cache_creation_price, cache_read_price, request_price,
            input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
            request_cost, total_cost, _tier_index
        ) = await cls._calculate_costs(
            db=db,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        # 使用共享逻辑准备记录参数
        params = UsageRecordParams(
            db=db, user=user, api_key=api_key, provider=provider, model=model,
            input_tokens=input_tokens, output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            api_format=api_format,
            cache_ttl_minutes=cache_ttl_minutes,
            use_tiered_pricing=use_tiered_pricing,
            is_failed_request=is_failed_request,
        )

        # 构建 Usage 参数
        usage_params = cls._build_usage_params(
            db=db,
            user=user,
            api_key=api_key,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            request_type=request_type,
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            first_byte_time_ms=first_byte_time_ms,
            status_code=status_code,
            error_message=error_message,
            metadata=metadata,
            request_headers=request_headers,
            request_body=request_body,
            request_type=request_type, api_format=api_format, is_stream=is_stream,
            response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
            status_code=status_code, error_message=error_message, metadata=metadata,
            request_headers=request_headers, request_body=request_body,
            provider_request_headers=provider_request_headers,
            response_headers=response_headers,
            response_body=response_body,
            request_id=request_id,
            provider_id=provider_id,
            response_headers=response_headers, response_body=response_body,
            request_id=request_id, provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
            status=status,
            provider_api_key_id=provider_api_key_id, status=status,
            cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
            target_model=target_model,
            input_cost=input_cost,
            output_cost=output_cost,
            cache_creation_cost=cache_creation_cost,
            cache_read_cost=cache_read_cost,
            cache_cost=cache_cost,
            request_cost=request_cost,
            total_cost=total_cost,
            input_price=input_price,
            output_price=output_price,
            cache_creation_price=cache_creation_price,
            cache_read_price=cache_read_price,
            request_price=request_price,
            actual_rate_multiplier=actual_rate_multiplier,
            is_free_tier=is_free_tier,
        )
        usage_params, total_cost = await cls._prepare_usage_record(params)

        # 检查是否已存在相同 request_id 的记录
        existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
@@ -751,7 +806,7 @@ class UsageService:
            api_key = db.merge(api_key)

        # 使用原子更新避免并发竞态条件
        from sqlalchemy import func, update
        from sqlalchemy import func as sql_func, update
        from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel

        # 更新用户使用量(独立 Key 不计入创建者的使用记录)
@@ -762,7 +817,7 @@ class UsageService:
            .values(
                used_usd=UserModel.used_usd + total_cost,
                total_usd=UserModel.total_usd + total_cost,
                updated_at=func.now(),
                updated_at=sql_func.now(),
            )
        )

@@ -776,8 +831,8 @@ class UsageService:
                total_requests=ApiKeyModel.total_requests + 1,
                total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
                balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
                last_used_at=func.now(),
                updated_at=func.now(),
                last_used_at=sql_func.now(),
                updated_at=sql_func.now(),
            )
        )
        else:
@@ -787,8 +842,8 @@ class UsageService:
            .values(
                total_requests=ApiKeyModel.total_requests + 1,
                total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
                last_used_at=func.now(),
                updated_at=func.now(),
                last_used_at=sql_func.now(),
                updated_at=sql_func.now(),
            )
        )

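The `sql_func.now()` rename above does not change the write pattern; for readers unfamiliar with it, here is a minimal, self-contained sketch of the same atomic server-side increment in isolation. The model import path is taken from the diff, but the helper name `charge_api_key` and the explicit `commit()` are assumptions for illustration only.

from sqlalchemy import update
from sqlalchemy import func as sql_func
from sqlalchemy.orm import Session

from src.models.database import ApiKey as ApiKeyModel


def charge_api_key(db: Session, api_key_id: int, total_cost: float) -> None:
    # The counters are incremented inside a single UPDATE statement, so two
    # concurrent requests cannot read-modify-write the same row and lose an update.
    db.execute(
        update(ApiKeyModel)
        .where(ApiKeyModel.id == api_key_id)
        .values(
            total_requests=ApiKeyModel.total_requests + 1,
            total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
            last_used_at=sql_func.now(),
            updated_at=sql_func.now(),
        )
    )
    db.commit()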
@@ -1121,19 +1176,48 @@ class UsageService:
        ]

    @staticmethod
    def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
        """清理旧的使用记录"""
    def cleanup_old_usage_records(
        db: Session, days_to_keep: int = 90, batch_size: int = 1000
    ) -> int:
        """清理旧的使用记录(分批删除避免长事务锁定)

        Args:
            db: 数据库会话
            days_to_keep: 保留天数,默认 90 天
            batch_size: 每批删除数量,默认 1000 条

        Returns:
            删除的总记录数
        """
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
        total_deleted = 0

        # 删除旧记录
        deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
        while True:
            # 查询待删除的 ID(使用新索引 idx_usage_user_created)
            batch_ids = (
                db.query(Usage.id)
                .filter(Usage.created_at < cutoff_date)
                .limit(batch_size)
                .all()
            )

        db.commit()
            if not batch_ids:
                break

        logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
            # 批量删除
            deleted_count = (
                db.query(Usage)
                .filter(Usage.id.in_([row.id for row in batch_ids]))
                .delete(synchronize_session=False)
            )
            db.commit()
            total_deleted += deleted_count

        return deleted
            logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")

        logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")

        return total_deleted

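A minimal sketch of how the batched cleanup might be wired into a periodic job. The session factory import (`SessionLocal`) and the job function itself are assumptions and are not part of this diff; only `UsageService.cleanup_old_usage_records` comes from the code above.

import asyncio

from src.services.usage.service import UsageService
from src.core.database import SessionLocal  # assumed sessionmaker; adjust to the project's actual factory


async def purge_old_usage_daily(days_to_keep: int = 90, batch_size: int = 1000) -> None:
    # Run once per day; each call deletes in small batches so no single long
    # transaction holds row locks on the usage table.
    while True:
        db = SessionLocal()
        try:
            UsageService.cleanup_old_usage_records(
                db, days_to_keep=days_to_keep, batch_size=batch_size
            )
        finally:
            db.close()
        await asyncio.sleep(24 * 60 * 60)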
    # ========== 请求状态追踪方法 ==========

@@ -1219,6 +1303,7 @@ class UsageService:
        error_message: Optional[str] = None,
        provider: Optional[str] = None,
        target_model: Optional[str] = None,
        first_byte_time_ms: Optional[int] = None,
    ) -> Optional[Usage]:
        """
        快速更新使用记录状态
@@ -1230,6 +1315,7 @@ class UsageService:
            error_message: 错误消息(仅在 failed 状态时使用)
            provider: 提供商名称(可选,streaming 状态时更新)
            target_model: 映射后的目标模型名(可选)
            first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)

        Returns:
            更新后的 Usage 记录,如果未找到则返回 None
@@ -1247,6 +1333,8 @@ class UsageService:
            usage.provider = provider
        if target_model:
            usage.target_model = target_model
        if first_byte_time_ms is not None:
            usage.first_byte_time_ms = first_byte_time_ms

        db.commit()

@@ -5,6 +5,7 @@

import json
import re
import time
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from sqlalchemy.orm import Session
@@ -457,26 +458,32 @@ class StreamUsageTracker:

        logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")

        # 更新状态为 streaming,同时更新 provider
        if self.request_id:
            try:
                from src.services.usage.service import UsageService
                UsageService.update_usage_status(
                    db=self.db,
                    request_id=self.request_id,
                    status="streaming",
                    provider=self.provider,
                )
            except Exception as e:
                logger.warning(f"更新使用记录状态为 streaming 失败: {e}")

        chunk_count = 0
        first_chunk_received = False
        try:
            async for chunk in stream:
                chunk_count += 1
                # 保存原始字节流(用于错误诊断)
                self.raw_chunks.append(chunk)

                # 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
                if not first_chunk_received:
                    first_chunk_received = True
                    if self.request_id:
                        try:
                            # 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
                            base_time = self.request_start_time or self.start_time
                            first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
                            UsageService.update_usage_status(
                                db=self.db,
                                request_id=self.request_id,
                                status="streaming",
                                provider=self.provider,
                                first_byte_time_ms=first_byte_time_ms,
                            )
                        except Exception as e:
                            logger.warning(f"更新使用记录状态为 streaming 失败: {e}")

                # 返回原始块给客户端
                yield chunk

@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
        # Python 3.10 及以下版本的兼容实现
        # 注意:这个简单实现不支持嵌套取消
        pass


async def read_first_chunk_with_ttfb_timeout(
    byte_iterator: Any,
    timeout: float,
    request_id: str,
    provider_name: str,
) -> tuple[bytes, Any]:
    """
    读取流的首字节并应用 TTFB 超时检测

    首字节超时(Time To First Byte)用于检测慢响应的 Provider,
    超时时触发故障转移到其他可用的 Provider。

    Args:
        byte_iterator: 异步字节流迭代器
        timeout: TTFB 超时时间(秒)
        request_id: 请求 ID(用于日志)
        provider_name: Provider 名称(用于日志和异常)

    Returns:
        (first_chunk, aiter): 首个字节块和异步迭代器

    Raises:
        ProviderTimeoutException: 如果首字节超时
    """
    from src.core.exceptions import ProviderTimeoutException

    aiter = byte_iterator.__aiter__()

    try:
        first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
        return first_chunk, aiter
    except asyncio.TimeoutError:
        # 完整的资源清理:先关闭迭代器,再关闭底层响应
        await _cleanup_iterator_resources(aiter, request_id)
        logger.warning(
            f" [{request_id}] 流首字节超时 (TTFB): "
            f"Provider={provider_name}, timeout={timeout}s"
        )
        raise ProviderTimeoutException(
            provider_name=provider_name,
            timeout=int(timeout),
        )


async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
    """
    清理异步迭代器及其底层资源

    确保在 TTFB 超时后正确释放 HTTP 连接,避免连接泄漏。

    Args:
        aiter: 异步迭代器
        request_id: 请求 ID(用于日志)
    """
    # 1. 关闭迭代器本身
    if hasattr(aiter, "aclose"):
        try:
            await aiter.aclose()
        except Exception as e:
            logger.debug(f" [{request_id}] 关闭迭代器失败: {e}")

    # 2. 关闭底层响应对象(httpx.Response)
    # 迭代器可能持有 _response 属性指向底层响应
    response = getattr(aiter, "_response", None)
    if response is not None and hasattr(response, "aclose"):
        try:
            await response.aclose()
        except Exception as e:
            logger.debug(f" [{request_id}] 关闭底层响应失败: {e}")

    # 3. 尝试关闭 httpx 流(如果迭代器是 httpx 的 aiter_bytes)
    # httpx 的 Response.aiter_bytes() 返回的生成器可能有 _stream 属性
    stream = getattr(aiter, "_stream", None)
    if stream is not None and hasattr(stream, "aclose"):
        try:
            await stream.aclose()
        except Exception as e:
            logger.debug(f" [{request_id}] 关闭流对象失败: {e}")

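A minimal usage sketch for the TTFB helper above. The wrapper name `stream_with_ttfb_guard` is hypothetical, but the call pattern (probe the first chunk, then replay it ahead of the remaining stream) follows directly from the helper's return value; it assumes `read_first_chunk_with_ttfb_timeout` is importable from the module shown above.

from typing import Any, AsyncIterator


async def stream_with_ttfb_guard(
    byte_iterator: Any,
    timeout: float,
    request_id: str,
    provider_name: str,
) -> AsyncIterator[bytes]:
    # Raises ProviderTimeoutException if nothing arrives within `timeout` seconds,
    # letting the caller fail over to another provider before any bytes reach the client.
    first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
        byte_iterator, timeout, request_id, provider_name
    )
    yield first_chunk  # re-emit the chunk consumed by the probe
    async for chunk in aiter:  # then pass the rest of the stream through untouched
        yield chunk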
@@ -8,86 +8,116 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio
class TestExtractCacheCreationTokens:
    """测试 extract_cache_creation_tokens 函数"""

    def test_new_format_only(self) -> None:
        """测试只有新格式字段"""
    # === 嵌套格式测试(优先级最高)===

    def test_nested_cache_creation_format(self) -> None:
        """测试嵌套格式正常情况"""
        usage = {
            "cache_creation": {
                "ephemeral_5m_input_tokens": 456,
                "ephemeral_1h_input_tokens": 100,
            }
        }
        assert extract_cache_creation_tokens(usage) == 556

    def test_nested_cache_creation_with_old_format_fallback(self) -> None:
        """测试嵌套格式为 0 时回退到旧格式"""
        usage = {
            "cache_creation": {
                "ephemeral_5m_input_tokens": 0,
                "ephemeral_1h_input_tokens": 0,
            },
            "cache_creation_input_tokens": 549,
        }
        assert extract_cache_creation_tokens(usage) == 549

    def test_nested_has_priority_over_flat(self) -> None:
        """测试嵌套格式优先于扁平格式"""
        usage = {
            "cache_creation": {
                "ephemeral_5m_input_tokens": 100,
                "ephemeral_1h_input_tokens": 200,
            },
            "claude_cache_creation_5_m_tokens": 999,  # 应该被忽略
            "claude_cache_creation_1_h_tokens": 888,  # 应该被忽略
            "cache_creation_input_tokens": 777,  # 应该被忽略
        }
        assert extract_cache_creation_tokens(usage) == 300

    # === 扁平格式测试(优先级第二)===

    def test_flat_new_format_still_works(self) -> None:
        """测试扁平新格式兼容性"""
        usage = {
            "claude_cache_creation_5_m_tokens": 100,
            "claude_cache_creation_1_h_tokens": 200,
        }
        assert extract_cache_creation_tokens(usage) == 300

    def test_new_format_5m_only(self) -> None:
        """测试只有 5 分钟缓存"""
    def test_flat_new_format_with_old_format_fallback(self) -> None:
        """测试扁平格式为 0 时回退到旧格式"""
        usage = {
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 0,
            "cache_creation_input_tokens": 549,
        }
        assert extract_cache_creation_tokens(usage) == 549

    def test_flat_new_format_5m_only(self) -> None:
        """测试只有 5 分钟扁平缓存"""
        usage = {
            "claude_cache_creation_5_m_tokens": 150,
            "claude_cache_creation_1_h_tokens": 0,
        }
        assert extract_cache_creation_tokens(usage) == 150

    def test_new_format_1h_only(self) -> None:
        """测试只有 1 小时缓存"""
    def test_flat_new_format_1h_only(self) -> None:
        """测试只有 1 小时扁平缓存"""
        usage = {
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 250,
        }
        assert extract_cache_creation_tokens(usage) == 250

    # === 旧格式测试(优先级第三)===

    def test_old_format_only(self) -> None:
        """测试只有旧格式字段"""
        """测试只有旧格式"""
        usage = {
            "cache_creation_input_tokens": 500,
            "cache_creation_input_tokens": 549,
        }
        assert extract_cache_creation_tokens(usage) == 500
        assert extract_cache_creation_tokens(usage) == 549

    def test_both_formats_prefers_new(self) -> None:
        """测试同时存在时优先使用新格式"""
        usage = {
            "claude_cache_creation_5_m_tokens": 100,
            "claude_cache_creation_1_h_tokens": 200,
            "cache_creation_input_tokens": 999,  # 应该被忽略
        }
        assert extract_cache_creation_tokens(usage) == 300
    # === 边界情况测试 ===

    def test_empty_usage(self) -> None:
        """测试空字典"""
    def test_no_cache_creation_tokens(self) -> None:
        """测试没有任何缓存字段"""
        usage = {}
        assert extract_cache_creation_tokens(usage) == 0

    def test_all_zeros(self) -> None:
        """测试所有字段都为 0"""
    def test_all_formats_zero(self) -> None:
        """测试所有格式都为 0"""
        usage = {
            "cache_creation": {
                "ephemeral_5m_input_tokens": 0,
                "ephemeral_1h_input_tokens": 0,
            },
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 0,
            "cache_creation_input_tokens": 0,
        }
        assert extract_cache_creation_tokens(usage) == 0

    def test_partial_new_format_with_old_format_fallback(self) -> None:
        """测试新格式字段不存在时回退到旧格式"""
        usage = {
            "cache_creation_input_tokens": 123,
        }
        assert extract_cache_creation_tokens(usage) == 123

    def test_new_format_zero_should_not_fallback(self) -> None:
        """测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
        usage = {
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 0,
            "cache_creation_input_tokens": 456,
        }
        # 新格式字段存在,即使值为 0 也应该使用新格式(返回 0)
        # 而不是 fallback 到旧格式(返回 456)
        assert extract_cache_creation_tokens(usage) == 0

    def test_unrelated_fields_ignored(self) -> None:
        """测试忽略无关字段"""
        usage = {
            "input_tokens": 1000,
            "output_tokens": 2000,
            "cache_read_input_tokens": 300,
            "claude_cache_creation_5_m_tokens": 50,
            "claude_cache_creation_1_h_tokens": 75,
            "cache_creation": {
                "ephemeral_5m_input_tokens": 50,
                "ephemeral_1h_input_tokens": 75,
            },
        }
        assert extract_cache_creation_tokens(usage) == 125

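The tests above pin down a three-level priority for cache-creation token extraction: the nested `cache_creation` object wins, then the flat per-TTL fields, then the legacy aggregate, with each level falling back to the next only when its values sum to zero. A minimal implementation consistent with those tests, offered purely as a reading aid; the real function lives in `src/api/handlers/base/utils.py` and may differ in detail.

from typing import Any, Dict


def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
    # 1. Nested format has the highest priority.
    nested = usage.get("cache_creation") or {}
    total = int(nested.get("ephemeral_5m_input_tokens", 0) or 0) + int(
        nested.get("ephemeral_1h_input_tokens", 0) or 0
    )
    if total > 0:
        return total
    # 2. Flat per-TTL fields come next.
    total = int(usage.get("claude_cache_creation_5_m_tokens", 0) or 0) + int(
        usage.get("claude_cache_creation_1_h_tokens", 0) or 0
    )
    if total > 0:
        return total
    # 3. Fall back to the legacy aggregate field.
    return int(usage.get("cache_creation_input_tokens", 0) or 0)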