From a12b43ce5cd6e2b13ba279a8e66f7c15c1c292dd Mon Sep 17 00:00:00 2001 From: fawney19 Date: Wed, 7 Jan 2026 19:53:32 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=B8=85=E7=90=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=AD=97=E6=AE=B5=E5=91=BD=E5=90=8D=E6=AD=A7?= =?UTF-8?q?=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误) - api_keys 表:删除 allowed_endpoints 字段(未使用的功能) - providers 表:删除 rate_limit 字段(与 rpm_limit 重复) - usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆) 同步更新前后端所有相关代码 --- ...66b7c4_cleanup_allowed_endpoints_fields.py | 73 +++++++++++++++++++ frontend/src/api/admin.ts | 3 +- frontend/src/api/auth.ts | 2 +- frontend/src/api/users.ts | 6 +- .../users/components/UserFormDialog.vue | 22 +++--- frontend/src/mocks/data.ts | 14 ++-- frontend/src/mocks/handler.ts | 2 +- frontend/src/views/admin/Users.vue | 6 +- src/api/admin/providers/routes.py | 3 - src/api/admin/system.py | 11 +-- src/api/admin/usage/routes.py | 16 ++-- src/api/admin/users/routes.py | 8 +- src/api/auth/routes.py | 2 +- src/api/base/models_service.py | 3 +- src/api/dashboard/routes.py | 4 +- src/api/public/models.py | 4 +- src/api/user_me/routes.py | 2 +- src/models/admin_requests.py | 4 +- src/models/api.py | 5 +- src/models/database.py | 6 +- src/services/cache/aware_scheduler.py | 34 +++------ src/services/system/stats_aggregator.py | 2 +- src/services/usage/service.py | 26 +++---- src/services/user/service.py | 4 +- 24 files changed, 155 insertions(+), 107 deletions(-) create mode 100644 alembic/versions/20260107_1120_02a45b66b7c4_cleanup_allowed_endpoints_fields.py diff --git a/alembic/versions/20260107_1120_02a45b66b7c4_cleanup_allowed_endpoints_fields.py b/alembic/versions/20260107_1120_02a45b66b7c4_cleanup_allowed_endpoints_fields.py new file mode 100644 index 0000000..920a5dc --- /dev/null +++ b/alembic/versions/20260107_1120_02a45b66b7c4_cleanup_allowed_endpoints_fields.py @@ -0,0 +1,73 @@ +"""cleanup ambiguous database fields + +Revision ID: 02a45b66b7c4 +Revises: ad55f1d008b7 +Create Date: 2026-01-07 11:20:12.684426+00:00 + +变更内容: +1. users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误) +2. api_keys 表:删除 allowed_endpoints 字段(未使用的功能) +3. providers 表:删除 rate_limit 字段(与 rpm_limit 功能重复,且未使用) +4. usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆) +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy import inspect + + +# revision identifiers, used by Alembic. +revision = '02a45b66b7c4' +down_revision = 'ad55f1d008b7' +branch_labels = None +depends_on = None + + +def _column_exists(table_name: str, column_name: str) -> bool: + """检查列是否存在""" + bind = op.get_bind() + inspector = inspect(bind) + columns = [col['name'] for col in inspector.get_columns(table_name)] + return column_name in columns + + +def upgrade() -> None: + """ + 1. users.allowed_endpoints -> allowed_api_formats(重命名) + 2. api_keys.allowed_endpoints 删除 + 3. providers.rate_limit 删除(与 rpm_limit 重复) + 4. usage.provider -> provider_name(重命名) + """ + # 1. users 表:重命名 allowed_endpoints 为 allowed_api_formats + if _column_exists('users', 'allowed_endpoints'): + op.alter_column('users', 'allowed_endpoints', new_column_name='allowed_api_formats') + + # 2. api_keys 表:删除 allowed_endpoints 字段 + if _column_exists('api_keys', 'allowed_endpoints'): + op.drop_column('api_keys', 'allowed_endpoints') + + # 3. providers 表:删除 rate_limit 字段(与 rpm_limit 功能重复) + if _column_exists('providers', 'rate_limit'): + op.drop_column('providers', 'rate_limit') + + # 4. usage 表:重命名 provider 为 provider_name + if _column_exists('usage', 'provider'): + op.alter_column('usage', 'provider', new_column_name='provider_name') + + +def downgrade() -> None: + """回滚:恢复原字段""" + # 4. usage 表:将 provider_name 改回 provider + if _column_exists('usage', 'provider_name'): + op.alter_column('usage', 'provider_name', new_column_name='provider') + + # 3. providers 表:恢复 rate_limit 字段 + if not _column_exists('providers', 'rate_limit'): + op.add_column('providers', sa.Column('rate_limit', sa.Integer(), nullable=True)) + + # 2. api_keys 表:恢复 allowed_endpoints 字段 + if not _column_exists('api_keys', 'allowed_endpoints'): + op.add_column('api_keys', sa.Column('allowed_endpoints', sa.JSON(), nullable=True)) + + # 1. users 表:将 allowed_api_formats 改回 allowed_endpoints + if _column_exists('users', 'allowed_api_formats'): + op.alter_column('users', 'allowed_api_formats', new_column_name='allowed_endpoints') diff --git a/frontend/src/api/admin.ts b/frontend/src/api/admin.ts index 45adee3..88ec952 100644 --- a/frontend/src/api/admin.ts +++ b/frontend/src/api/admin.ts @@ -22,7 +22,7 @@ export interface UserExport { password_hash: string role: string allowed_providers?: string[] | null - allowed_endpoints?: string[] | null + allowed_api_formats?: string[] | null allowed_models?: string[] | null model_capability_settings?: any quota_usd?: number | null @@ -40,7 +40,6 @@ export interface UserApiKeyExport { balance_used_usd?: number current_balance_usd?: number | null allowed_providers?: string[] | null - allowed_endpoints?: string[] | null allowed_api_formats?: string[] | null allowed_models?: string[] | null rate_limit?: number | null // null = 无限制 diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 43408bb..30b01ff 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -98,7 +98,7 @@ export interface User { used_usd?: number total_usd?: number allowed_providers?: string[] | null // 允许使用的提供商 ID 列表 - allowed_endpoints?: string[] | null // 允许使用的端点 ID 列表 + allowed_api_formats?: string[] | null // 允许使用的 API 格式列表 allowed_models?: string[] | null // 允许使用的模型名称列表 created_at: string last_login_at?: string diff --git a/frontend/src/api/users.ts b/frontend/src/api/users.ts index d30054c..c178f80 100644 --- a/frontend/src/api/users.ts +++ b/frontend/src/api/users.ts @@ -10,7 +10,7 @@ export interface User { used_usd: number total_usd: number allowed_providers: string[] | null // 允许使用的提供商 ID 列表 - allowed_endpoints: string[] | null // 允许使用的端点 ID 列表 + allowed_api_formats: string[] | null // 允许使用的 API 格式列表 allowed_models: string[] | null // 允许使用的模型名称列表 created_at: string updated_at?: string @@ -23,7 +23,7 @@ export interface CreateUserRequest { role?: 'admin' | 'user' quota_usd?: number | null allowed_providers?: string[] | null - allowed_endpoints?: string[] | null + allowed_api_formats?: string[] | null allowed_models?: string[] | null } @@ -34,7 +34,7 @@ export interface UpdateUserRequest { quota_usd?: number | null password?: string allowed_providers?: string[] | null - allowed_endpoints?: string[] | null + allowed_api_formats?: string[] | null allowed_models?: string[] | null } diff --git a/frontend/src/features/users/components/UserFormDialog.vue b/frontend/src/features/users/components/UserFormDialog.vue index c7e7d42..80e2c7a 100644 --- a/frontend/src/features/users/components/UserFormDialog.vue +++ b/frontend/src/features/users/components/UserFormDialog.vue @@ -273,8 +273,8 @@ class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors" @click="endpointDropdownOpen = !endpointDropdownOpen" > - - {{ form.allowed_endpoints.length ? `已选择 ${form.allowed_endpoints.length} 个` : '全部可用' }} + + {{ form.allowed_api_formats.length ? `已选择 ${form.allowed_api_formats.length} 个` : '全部可用' }} {{ format.label }} @@ -374,7 +374,7 @@ export interface UserFormData { role: 'admin' | 'user' is_active?: boolean allowed_providers?: string[] | null - allowed_endpoints?: string[] | null + allowed_api_formats?: string[] | null allowed_models?: string[] | null } @@ -414,7 +414,7 @@ const form = ref({ unlimited: false, is_active: true, allowed_providers: [] as string[], - allowed_endpoints: [] as string[], + allowed_api_formats: [] as string[], allowed_models: [] as string[] }) @@ -435,7 +435,7 @@ function resetForm() { unlimited: false, is_active: true, allowed_providers: [], - allowed_endpoints: [], + allowed_api_formats: [], allowed_models: [] } } @@ -454,7 +454,7 @@ function loadUserData() { unlimited: props.user.quota_usd == null, is_active: props.user.is_active ?? true, allowed_providers: props.user.allowed_providers || [], - allowed_endpoints: props.user.allowed_endpoints || [], + allowed_api_formats: props.user.allowed_api_formats || [], allowed_models: props.user.allowed_models || [] } } @@ -495,7 +495,7 @@ async function loadAccessControlOptions() { } // 切换选择 -function toggleSelection(field: 'allowed_providers' | 'allowed_endpoints' | 'allowed_models', value: string) { +function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'allowed_models', value: string) { const arr = form.value[field] const index = arr.indexOf(value) if (index === -1) { @@ -520,7 +520,7 @@ async function handleSubmit() { quota_usd: form.value.unlimited ? null : form.value.quota, role: form.value.role, allowed_providers: form.value.allowed_providers.length > 0 ? form.value.allowed_providers : null, - allowed_endpoints: form.value.allowed_endpoints.length > 0 ? form.value.allowed_endpoints : null, + allowed_api_formats: form.value.allowed_api_formats.length > 0 ? form.value.allowed_api_formats : null, allowed_models: form.value.allowed_models.length > 0 ? form.value.allowed_models : null } diff --git a/frontend/src/mocks/data.ts b/frontend/src/mocks/data.ts index 81702ad..e972ca4 100644 --- a/frontend/src/mocks/data.ts +++ b/frontend/src/mocks/data.ts @@ -22,7 +22,7 @@ export const MOCK_ADMIN_USER: User = { used_usd: 156.78, total_usd: 1234.56, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-01-01T00:00:00Z', last_login_at: new Date().toISOString() @@ -38,7 +38,7 @@ export const MOCK_NORMAL_USER: User = { used_usd: 45.32, total_usd: 245.32, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-06-01T00:00:00Z', last_login_at: new Date().toISOString() @@ -274,7 +274,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [ used_usd: 156.78, total_usd: 1234.56, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-01-01T00:00:00Z' }, @@ -288,7 +288,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [ used_usd: 45.32, total_usd: 245.32, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-06-01T00:00:00Z' }, @@ -302,7 +302,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [ used_usd: 23.45, total_usd: 123.45, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-03-15T00:00:00Z' }, @@ -316,7 +316,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [ used_usd: 89.12, total_usd: 589.12, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-02-20T00:00:00Z' }, @@ -330,7 +330,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [ used_usd: 30.00, total_usd: 30.00, allowed_providers: null, - allowed_endpoints: null, + allowed_api_formats: null, allowed_models: null, created_at: '2024-04-10T00:00:00Z' } diff --git a/frontend/src/mocks/handler.ts b/frontend/src/mocks/handler.ts index 15c3290..6ecb7f3 100644 --- a/frontend/src/mocks/handler.ts +++ b/frontend/src/mocks/handler.ts @@ -690,7 +690,7 @@ const mockHandlers: Record Promise= self.start_date) @@ -565,8 +565,8 @@ class AdminUsageByApiFormatAdapter(AdminApiAdapter): ) # 过滤掉 pending/streaming 状态的请求 query = query.filter(Usage.status.notin_(["pending", "streaming"])) - # 过滤掉 unknown/pending provider - query = query.filter(Usage.provider.notin_(["unknown", "pending"])) + # 过滤掉 unknown/pending provider_name + query = query.filter(Usage.provider_name.notin_(["unknown", "pending"])) # 只统计有 api_format 的记录 query = query.filter(Usage.api_format.isnot(None)) @@ -765,8 +765,8 @@ class AdminUsageRecordsAdapter(AdminApiAdapter): float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0 ) - # 提供商名称优先级:关联的 Provider 表 > usage.provider 字段 - provider_name = usage.provider + # 提供商名称优先级:关联的 Provider 表 > usage.provider_name 字段 + provider_name = usage.provider_name if usage.provider_id and str(usage.provider_id) in provider_map: provider_name = provider_map[str(usage.provider_id)] @@ -881,7 +881,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter): "name": api_key.name if api_key else None, "display": api_key.get_display_key() if api_key else None, }, - "provider": usage_record.provider, + "provider": usage_record.provider_name, "api_format": usage_record.api_format, "model": usage_record.model, "target_model": usage_record.target_model, @@ -934,7 +934,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter): # 尝试获取模型的阶梯配置(带来源信息) cost_service = ModelCostService(db) pricing_result = await cost_service.get_tiered_pricing_with_source_async( - usage_record.provider, usage_record.model + usage_record.provider_name, usage_record.model ) if not pricing_result: diff --git a/src/api/admin/users/routes.py b/src/api/admin/users/routes.py index 47a6e27..c4475cb 100644 --- a/src/api/admin/users/routes.py +++ b/src/api/admin/users/routes.py @@ -246,7 +246,7 @@ class AdminCreateUserAdapter(AdminApiAdapter): "username": user.username, "role": user.role.value, "allowed_providers": user.allowed_providers, - "allowed_endpoints": user.allowed_endpoints, + "allowed_api_formats": user.allowed_api_formats, "allowed_models": user.allowed_models, "quota_usd": user.quota_usd, "used_usd": user.used_usd, @@ -274,7 +274,7 @@ class AdminListUsersAdapter(AdminApiAdapter): "username": u.username, "role": u.role.value, "allowed_providers": u.allowed_providers, - "allowed_endpoints": u.allowed_endpoints, + "allowed_api_formats": u.allowed_api_formats, "allowed_models": u.allowed_models, "quota_usd": u.quota_usd, "used_usd": u.used_usd, @@ -309,7 +309,7 @@ class AdminGetUserAdapter(AdminApiAdapter): "username": user.username, "role": user.role.value, "allowed_providers": user.allowed_providers, - "allowed_endpoints": user.allowed_endpoints, + "allowed_api_formats": user.allowed_api_formats, "allowed_models": user.allowed_models, "quota_usd": user.quota_usd, "used_usd": user.used_usd, @@ -375,7 +375,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter): "username": user.username, "role": user.role.value, "allowed_providers": user.allowed_providers, - "allowed_endpoints": user.allowed_endpoints, + "allowed_api_formats": user.allowed_api_formats, "allowed_models": user.allowed_models, "quota_usd": user.quota_usd, "used_usd": user.used_usd, diff --git a/src/api/auth/routes.py b/src/api/auth/routes.py index e50eb59..cd5c256 100644 --- a/src/api/auth/routes.py +++ b/src/api/auth/routes.py @@ -528,7 +528,7 @@ class AuthCurrentUserAdapter(AuthenticatedApiAdapter): "used_usd": user.used_usd, "total_usd": user.total_usd, "allowed_providers": user.allowed_providers, - "allowed_endpoints": user.allowed_endpoints, + "allowed_api_formats": user.allowed_api_formats, "allowed_models": user.allowed_models, "created_at": user.created_at.isoformat(), "last_login_at": user.last_login_at.isoformat() if user.last_login_at else None, diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py index 1df1594..69ceacc 100644 --- a/src/api/base/models_service.py +++ b/src/api/base/models_service.py @@ -143,12 +143,13 @@ class AccessRestrictions: allowed_api_formats = api_key.allowed_api_formats # 如果 API Key 没有限制,检查 User 的限制 - # 注意: User 没有 allowed_api_formats 字段 if user: if allowed_providers is None and user.allowed_providers is not None: allowed_providers = user.allowed_providers if allowed_models is None and user.allowed_models is not None: allowed_models = user.allowed_models + if allowed_api_formats is None and user.allowed_api_formats is not None: + allowed_api_formats = user.allowed_api_formats return cls( allowed_providers=allowed_providers, diff --git a/src/api/dashboard/routes.py b/src/api/dashboard/routes.py index 6360b1d..a4e5b8a 100644 --- a/src/api/dashboard/routes.py +++ b/src/api/dashboard/routes.py @@ -766,7 +766,7 @@ class DashboardProviderStatusAdapter(DashboardAdapter): for provider in providers: count = ( db.query(func.count(Usage.id)) - .filter(and_(Usage.provider == provider.name, Usage.created_at >= since)) + .filter(and_(Usage.provider_name == provider.name, Usage.created_at >= since)) .scalar() ) entries.append( @@ -854,7 +854,7 @@ class DashboardDailyStatsAdapter(DashboardAdapter): .scalar() or 0 ) today_unique_providers = ( - db.query(func.count(func.distinct(Usage.provider))) + db.query(func.count(func.distinct(Usage.provider_name))) .filter(Usage.created_at >= today) .scalar() or 0 ) diff --git a/src/api/public/models.py b/src/api/public/models.py index 949846f..f9f2243 100644 --- a/src/api/public/models.py +++ b/src/api/public/models.py @@ -126,7 +126,9 @@ def _filter_formats_by_restrictions( """ if restrictions.allowed_api_formats is None: return formats, None - filtered = [f for f in formats if f in restrictions.allowed_api_formats] + # 统一转为大写比较,兼容数据库中存储的大小写 + allowed_upper = {f.upper() for f in restrictions.allowed_api_formats} + filtered = [f for f in formats if f.upper() in allowed_upper] if not filtered: logger.info(f"[Models] API Key 不允许访问格式 {api_format}") return [], _build_empty_list_response(api_format) diff --git a/src/api/user_me/routes.py b/src/api/user_me/routes.py index 86eddd1..a81f930 100644 --- a/src/api/user_me/routes.py +++ b/src/api/user_me/routes.py @@ -847,7 +847,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter): "records": [ { "id": r.id, - "provider": r.provider, + "provider": r.provider_name, "model": r.model, "target_model": r.target_model, # 映射后的目标模型名 "api_format": r.api_format, diff --git a/src/models/admin_requests.py b/src/models/admin_requests.py index 1510151..3cef538 100644 --- a/src/models/admin_requests.py +++ b/src/models/admin_requests.py @@ -71,7 +71,6 @@ class CreateProviderRequest(BaseModel): rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制") provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)") is_active: Optional[bool] = Field(True, description="是否启用") - rate_limit: Optional[int] = Field(None, ge=0, description="速率限制") concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制") config: Optional[Dict[str, Any]] = Field(None, description="其他配置") @@ -174,7 +173,6 @@ class UpdateProviderRequest(BaseModel): rpm_limit: Optional[int] = Field(None, ge=0) provider_priority: Optional[int] = Field(None, ge=0, le=1000) is_active: Optional[bool] = None - rate_limit: Optional[int] = Field(None, ge=0) concurrent_limit: Optional[int] = Field(None, ge=0) config: Optional[Dict[str, Any]] = None @@ -322,7 +320,7 @@ class UpdateUserRequest(BaseModel): is_active: Optional[bool] = None role: Optional[str] = None allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表") - allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表") + allowed_api_formats: Optional[List[str]] = Field(None, description="允许使用的 API 格式列表") allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表") @field_validator("username") diff --git a/src/models/api.py b/src/models/api.py index 03b6397..5521154 100644 --- a/src/models/api.py +++ b/src/models/api.py @@ -293,7 +293,7 @@ class UpdateUserRequest(BaseModel): password: Optional[str] = None role: Optional[UserRole] = None allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表 - allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表 + allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表 allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表 quota_usd: Optional[float] = None is_active: Optional[bool] = None @@ -316,7 +316,6 @@ class CreateApiKeyRequest(BaseModel): name: Optional[str] = None allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表 - allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表 allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表 allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表 rate_limit: Optional[int] = None # None = 无限制 @@ -339,7 +338,7 @@ class UserResponse(BaseModel): username: str role: UserRole allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表 - allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表 + allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表 allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表 quota_usd: float used_usd: float diff --git a/src/models/database.py b/src/models/database.py index fdf0815..35df330 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -72,7 +72,7 @@ class User(Base): # 访问限制(NULL 表示不限制,允许访问所有资源) allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表 - allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表 + allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表 allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表 # Key 能力配置 @@ -165,7 +165,6 @@ class ApiKey(Base): # 访问限制(NULL 表示不限制,允许访问所有资源) allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表 - allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表 allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表 allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表 rate_limit = Column(Integer, default=None, nullable=True) # 每分钟请求限制,None = 无限制 @@ -272,7 +271,7 @@ class Usage(Base): # 请求信息 request_id = Column(String(100), unique=True, index=True, nullable=False) - provider = Column(String(100), nullable=False) + provider_name = Column(String(100), nullable=False) # Provider 名称(非外键) model = Column(String(100), nullable=False) target_model = Column(String(100), nullable=True, comment="映射后的目标模型名(若无映射则为空)") @@ -554,7 +553,6 @@ class Provider(Base): is_active = Column(Boolean, default=True, nullable=False) # 限制 - rate_limit = Column(Integer, nullable=True) # 每分钟请求限制 concurrent_limit = Column(Integer, nullable=True) # 并发请求限制 # 配置 diff --git a/src/services/cache/aware_scheduler.py b/src/services/cache/aware_scheduler.py index f3628a0..3257240 100644 --- a/src/services/cache/aware_scheduler.py +++ b/src/services/cache/aware_scheduler.py @@ -486,11 +486,10 @@ class CacheAwareScheduler: user_api_key: 用户 API Key 对象(可能包含 user relationship) Returns: - 包含 allowed_providers, allowed_endpoints, allowed_models 的字典 + 包含 allowed_providers, allowed_models, allowed_api_formats 的字典 """ result = { "allowed_providers": None, - "allowed_endpoints": None, "allowed_models": None, "allowed_api_formats": None, } @@ -534,20 +533,16 @@ class CacheAwareScheduler: user_api_key.allowed_providers, user.allowed_providers if user else None ) - # 合并 allowed_endpoints - result["allowed_endpoints"] = merge_restrictions( - user_api_key.allowed_endpoints if hasattr(user_api_key, "allowed_endpoints") else None, - user.allowed_endpoints if user else None, - ) - # 合并 allowed_models result["allowed_models"] = merge_restrictions( user_api_key.allowed_models, user.allowed_models if user else None ) - # API 格式仅从 ApiKey 获取(User 不设置此限制) - if user_api_key.allowed_api_formats: - result["allowed_api_formats"] = set(user_api_key.allowed_api_formats) + # 合并 allowed_api_formats + result["allowed_api_formats"] = merge_restrictions( + user_api_key.allowed_api_formats, + user.allowed_api_formats if user else None + ) return result @@ -607,12 +602,13 @@ class CacheAwareScheduler: restrictions = self._get_effective_restrictions(user_api_key) allowed_api_formats = restrictions["allowed_api_formats"] allowed_providers = restrictions["allowed_providers"] - allowed_endpoints = restrictions["allowed_endpoints"] allowed_models = restrictions["allowed_models"] # 0.1 检查 API 格式是否被允许 if allowed_api_formats is not None: - if target_format.value not in allowed_api_formats: + # 统一转为大写比较,兼容数据库中存储的大小写 + allowed_upper = {f.upper() for f in allowed_api_formats} + if target_format.value.upper() not in allowed_upper: logger.debug( f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, " f"允许的格式: {allowed_api_formats}" @@ -659,7 +655,7 @@ class CacheAwareScheduler: if not providers: return [], global_model_id - # 2. 构建候选列表(传入 allowed_endpoints、is_stream 和 capability_requirements 用于过滤) + # 2. 构建候选列表(传入 is_stream 和 capability_requirements 用于过滤) candidates = await self._build_candidates( db=db, providers=providers, @@ -668,7 +664,6 @@ class CacheAwareScheduler: resolved_model_name=resolved_model_name, affinity_key=affinity_key, max_candidates=max_candidates, - allowed_endpoints=allowed_endpoints, is_stream=is_stream, capability_requirements=capability_requirements, ) @@ -905,7 +900,6 @@ class CacheAwareScheduler: affinity_key: Optional[str], resolved_model_name: Optional[str] = None, max_candidates: Optional[int] = None, - allowed_endpoints: Optional[set] = None, is_stream: bool = False, capability_requirements: Optional[Dict[str, bool]] = None, ) -> List[ProviderCandidate]: @@ -920,7 +914,6 @@ class CacheAwareScheduler: affinity_key: 亲和性标识符(通常为API Key ID) resolved_model_name: 解析后的 GlobalModel.name(用于 Key.allowed_models 校验) max_candidates: 最大候选数 - allowed_endpoints: 允许的 Endpoint ID 集合(None 表示不限制) is_stream: 是否是流式请求,如果为 True 则过滤不支持流式的 Provider capability_requirements: 能力需求(可选) @@ -949,13 +942,6 @@ class CacheAwareScheduler: if not endpoint.is_active or endpoint_format_str != target_format.value: continue - # 检查 Endpoint 是否在允许列表中 - if allowed_endpoints is not None and endpoint.id not in allowed_endpoints: - logger.debug( - f"Endpoint {endpoint.id[:8]}... 不在用户/API Key 的允许列表中,跳过" - ) - continue - # 获取活跃的 Key 并按 internal_priority + 负载均衡排序 active_keys = [key for key in endpoint.api_keys if key.is_active] # 检查是否所有 Key 都是 TTL=0(轮换模式) diff --git a/src/services/system/stats_aggregator.py b/src/services/system/stats_aggregator.py index bc9415e..3cd16fc 100644 --- a/src/services/system/stats_aggregator.py +++ b/src/services/system/stats_aggregator.py @@ -144,7 +144,7 @@ class StatsAggregatorService: or 0 ) unique_providers = ( - db.query(func.count(func.distinct(Usage.provider))) + db.query(func.count(func.distinct(Usage.provider_name))) .filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end)) .scalar() or 0 diff --git a/src/services/usage/service.py b/src/services/usage/service.py index 42f175e..d8d71d6 100644 --- a/src/services/usage/service.py +++ b/src/services/usage/service.py @@ -309,7 +309,7 @@ class UsageService: "user_id": user.id if user else None, "api_key_id": api_key.id if api_key else None, "request_id": request_id, - "provider": provider, + "provider_name": provider, "model": model, "target_model": target_model, "provider_id": provider_id, @@ -479,7 +479,7 @@ class UsageService: ) -> None: """更新已存在的 Usage 记录(内部方法)""" # 更新关键字段 - existing_usage.provider = usage_params["provider"] + existing_usage.provider_name = usage_params["provider_name"] existing_usage.status = usage_params["status"] existing_usage.status_code = usage_params["status_code"] existing_usage.error_message = usage_params["error_message"] @@ -1092,7 +1092,7 @@ class UsageService: # 汇总查询 summary = db.query( date_func.label("period"), - Usage.provider, + Usage.provider_name, Usage.model, func.count(Usage.id).label("requests"), func.sum(Usage.input_tokens).label("input_tokens"), @@ -1111,12 +1111,12 @@ class UsageService: if end_date: summary = summary.filter(Usage.created_at <= end_date) - summary = summary.group_by(date_func, Usage.provider, Usage.model).all() + summary = summary.group_by(date_func, Usage.provider_name, Usage.model).all() return [ { "period": row.period, - "provider": row.provider, + "provider": row.provider_name, "model": row.model, "requests": row.requests, "input_tokens": row.input_tokens, @@ -1445,7 +1445,7 @@ class UsageService: user_id=user.id if user else None, api_key_id=api_key.id if api_key else None, request_id=request_id, - provider="pending", # 尚未确定 provider + provider_name="pending", # 尚未确定 provider model=model, input_tokens=0, output_tokens=0, @@ -1508,12 +1508,12 @@ class UsageService: if error_message: usage.error_message = error_message if provider: - usage.provider = provider - elif status == "streaming" and usage.provider == "pending": - # 状态变为 streaming 但 provider 仍为 pending,记录警告 + usage.provider_name = provider + elif status == "streaming" and usage.provider_name == "pending": + # 状态变为 streaming 但 provider_name 仍为 pending,记录警告 logger.warning( - f"状态更新为 streaming 但 provider 为空: request_id={request_id}, " - f"当前 provider={usage.provider}" + f"状态更新为 streaming 但 provider_name 为空: request_id={request_id}, " + f"当前 provider_name={usage.provider_name}" ) if target_model: usage.target_model = target_model @@ -1679,7 +1679,7 @@ class UsageService: from src.models.database import ProviderAPIKey query = query.add_columns( - Usage.provider, + Usage.provider_name, ProviderAPIKey.name.label("api_key_name"), ).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id) @@ -1731,7 +1731,7 @@ class UsageService: "first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB) } if include_admin_fields: - item["provider"] = r.provider + item["provider"] = r.provider_name item["api_key_name"] = r.api_key_name result.append(item) diff --git a/src/services/user/service.py b/src/services/user/service.py index 5bbf6da..8e70e9a 100644 --- a/src/services/user/service.py +++ b/src/services/user/service.py @@ -182,12 +182,12 @@ class UserService: "role", # 访问限制字段 "allowed_providers", - "allowed_endpoints", + "allowed_api_formats", "allowed_models", ] # 允许设置为 None 的字段(表示无限制) - nullable_fields = ["quota_usd", "allowed_providers", "allowed_endpoints", "allowed_models"] + nullable_fields = ["quota_usd", "allowed_providers", "allowed_api_formats", "allowed_models"] for field, value in kwargs.items(): if field not in updatable_fields: