mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
refactor: 清理数据库字段命名歧义
- users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误) - api_keys 表:删除 allowed_endpoints 字段(未使用的功能) - providers 表:删除 rate_limit 字段(与 rpm_limit 重复) - usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆) 同步更新前后端所有相关代码
This commit is contained in:
34
src/services/cache/aware_scheduler.py
vendored
34
src/services/cache/aware_scheduler.py
vendored
@@ -486,11 +486,10 @@ class CacheAwareScheduler:
|
||||
user_api_key: 用户 API Key 对象(可能包含 user relationship)
|
||||
|
||||
Returns:
|
||||
包含 allowed_providers, allowed_endpoints, allowed_models 的字典
|
||||
包含 allowed_providers, allowed_models, allowed_api_formats 的字典
|
||||
"""
|
||||
result = {
|
||||
"allowed_providers": None,
|
||||
"allowed_endpoints": None,
|
||||
"allowed_models": None,
|
||||
"allowed_api_formats": None,
|
||||
}
|
||||
@@ -534,20 +533,16 @@ class CacheAwareScheduler:
|
||||
user_api_key.allowed_providers, user.allowed_providers if user else None
|
||||
)
|
||||
|
||||
# 合并 allowed_endpoints
|
||||
result["allowed_endpoints"] = merge_restrictions(
|
||||
user_api_key.allowed_endpoints if hasattr(user_api_key, "allowed_endpoints") else None,
|
||||
user.allowed_endpoints if user else None,
|
||||
)
|
||||
|
||||
# 合并 allowed_models
|
||||
result["allowed_models"] = merge_restrictions(
|
||||
user_api_key.allowed_models, user.allowed_models if user else None
|
||||
)
|
||||
|
||||
# API 格式仅从 ApiKey 获取(User 不设置此限制)
|
||||
if user_api_key.allowed_api_formats:
|
||||
result["allowed_api_formats"] = set(user_api_key.allowed_api_formats)
|
||||
# 合并 allowed_api_formats
|
||||
result["allowed_api_formats"] = merge_restrictions(
|
||||
user_api_key.allowed_api_formats,
|
||||
user.allowed_api_formats if user else None
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -607,12 +602,13 @@ class CacheAwareScheduler:
|
||||
restrictions = self._get_effective_restrictions(user_api_key)
|
||||
allowed_api_formats = restrictions["allowed_api_formats"]
|
||||
allowed_providers = restrictions["allowed_providers"]
|
||||
allowed_endpoints = restrictions["allowed_endpoints"]
|
||||
allowed_models = restrictions["allowed_models"]
|
||||
|
||||
# 0.1 检查 API 格式是否被允许
|
||||
if allowed_api_formats is not None:
|
||||
if target_format.value not in allowed_api_formats:
|
||||
# 统一转为大写比较,兼容数据库中存储的大小写
|
||||
allowed_upper = {f.upper() for f in allowed_api_formats}
|
||||
if target_format.value.upper() not in allowed_upper:
|
||||
logger.debug(
|
||||
f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
|
||||
f"允许的格式: {allowed_api_formats}"
|
||||
@@ -659,7 +655,7 @@ class CacheAwareScheduler:
|
||||
if not providers:
|
||||
return [], global_model_id
|
||||
|
||||
# 2. 构建候选列表(传入 allowed_endpoints、is_stream 和 capability_requirements 用于过滤)
|
||||
# 2. 构建候选列表(传入 is_stream 和 capability_requirements 用于过滤)
|
||||
candidates = await self._build_candidates(
|
||||
db=db,
|
||||
providers=providers,
|
||||
@@ -668,7 +664,6 @@ class CacheAwareScheduler:
|
||||
resolved_model_name=resolved_model_name,
|
||||
affinity_key=affinity_key,
|
||||
max_candidates=max_candidates,
|
||||
allowed_endpoints=allowed_endpoints,
|
||||
is_stream=is_stream,
|
||||
capability_requirements=capability_requirements,
|
||||
)
|
||||
@@ -905,7 +900,6 @@ class CacheAwareScheduler:
|
||||
affinity_key: Optional[str],
|
||||
resolved_model_name: Optional[str] = None,
|
||||
max_candidates: Optional[int] = None,
|
||||
allowed_endpoints: Optional[set] = None,
|
||||
is_stream: bool = False,
|
||||
capability_requirements: Optional[Dict[str, bool]] = None,
|
||||
) -> List[ProviderCandidate]:
|
||||
@@ -920,7 +914,6 @@ class CacheAwareScheduler:
|
||||
affinity_key: 亲和性标识符(通常为API Key ID)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(用于 Key.allowed_models 校验)
|
||||
max_candidates: 最大候选数
|
||||
allowed_endpoints: 允许的 Endpoint ID 集合(None 表示不限制)
|
||||
is_stream: 是否是流式请求,如果为 True 则过滤不支持流式的 Provider
|
||||
capability_requirements: 能力需求(可选)
|
||||
|
||||
@@ -949,13 +942,6 @@ class CacheAwareScheduler:
|
||||
if not endpoint.is_active or endpoint_format_str != target_format.value:
|
||||
continue
|
||||
|
||||
# 检查 Endpoint 是否在允许列表中
|
||||
if allowed_endpoints is not None and endpoint.id not in allowed_endpoints:
|
||||
logger.debug(
|
||||
f"Endpoint {endpoint.id[:8]}... 不在用户/API Key 的允许列表中,跳过"
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
||||
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
||||
# 检查是否所有 Key 都是 TTL=0(轮换模式)
|
||||
|
||||
@@ -144,7 +144,7 @@ class StatsAggregatorService:
|
||||
or 0
|
||||
)
|
||||
unique_providers = (
|
||||
db.query(func.count(func.distinct(Usage.provider)))
|
||||
db.query(func.count(func.distinct(Usage.provider_name)))
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
.scalar()
|
||||
or 0
|
||||
|
||||
@@ -309,7 +309,7 @@ class UsageService:
|
||||
"user_id": user.id if user else None,
|
||||
"api_key_id": api_key.id if api_key else None,
|
||||
"request_id": request_id,
|
||||
"provider": provider,
|
||||
"provider_name": provider,
|
||||
"model": model,
|
||||
"target_model": target_model,
|
||||
"provider_id": provider_id,
|
||||
@@ -479,7 +479,7 @@ class UsageService:
|
||||
) -> None:
|
||||
"""更新已存在的 Usage 记录(内部方法)"""
|
||||
# 更新关键字段
|
||||
existing_usage.provider = usage_params["provider"]
|
||||
existing_usage.provider_name = usage_params["provider_name"]
|
||||
existing_usage.status = usage_params["status"]
|
||||
existing_usage.status_code = usage_params["status_code"]
|
||||
existing_usage.error_message = usage_params["error_message"]
|
||||
@@ -1092,7 +1092,7 @@ class UsageService:
|
||||
# 汇总查询
|
||||
summary = db.query(
|
||||
date_func.label("period"),
|
||||
Usage.provider,
|
||||
Usage.provider_name,
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
@@ -1111,12 +1111,12 @@ class UsageService:
|
||||
if end_date:
|
||||
summary = summary.filter(Usage.created_at <= end_date)
|
||||
|
||||
summary = summary.group_by(date_func, Usage.provider, Usage.model).all()
|
||||
summary = summary.group_by(date_func, Usage.provider_name, Usage.model).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"period": row.period,
|
||||
"provider": row.provider,
|
||||
"provider": row.provider_name,
|
||||
"model": row.model,
|
||||
"requests": row.requests,
|
||||
"input_tokens": row.input_tokens,
|
||||
@@ -1445,7 +1445,7 @@ class UsageService:
|
||||
user_id=user.id if user else None,
|
||||
api_key_id=api_key.id if api_key else None,
|
||||
request_id=request_id,
|
||||
provider="pending", # 尚未确定 provider
|
||||
provider_name="pending", # 尚未确定 provider
|
||||
model=model,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
@@ -1508,12 +1508,12 @@ class UsageService:
|
||||
if error_message:
|
||||
usage.error_message = error_message
|
||||
if provider:
|
||||
usage.provider = provider
|
||||
elif status == "streaming" and usage.provider == "pending":
|
||||
# 状态变为 streaming 但 provider 仍为 pending,记录警告
|
||||
usage.provider_name = provider
|
||||
elif status == "streaming" and usage.provider_name == "pending":
|
||||
# 状态变为 streaming 但 provider_name 仍为 pending,记录警告
|
||||
logger.warning(
|
||||
f"状态更新为 streaming 但 provider 为空: request_id={request_id}, "
|
||||
f"当前 provider={usage.provider}"
|
||||
f"状态更新为 streaming 但 provider_name 为空: request_id={request_id}, "
|
||||
f"当前 provider_name={usage.provider_name}"
|
||||
)
|
||||
if target_model:
|
||||
usage.target_model = target_model
|
||||
@@ -1679,7 +1679,7 @@ class UsageService:
|
||||
from src.models.database import ProviderAPIKey
|
||||
|
||||
query = query.add_columns(
|
||||
Usage.provider,
|
||||
Usage.provider_name,
|
||||
ProviderAPIKey.name.label("api_key_name"),
|
||||
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||
|
||||
@@ -1731,7 +1731,7 @@ class UsageService:
|
||||
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
}
|
||||
if include_admin_fields:
|
||||
item["provider"] = r.provider
|
||||
item["provider"] = r.provider_name
|
||||
item["api_key_name"] = r.api_key_name
|
||||
result.append(item)
|
||||
|
||||
|
||||
@@ -182,12 +182,12 @@ class UserService:
|
||||
"role",
|
||||
# 访问限制字段
|
||||
"allowed_providers",
|
||||
"allowed_endpoints",
|
||||
"allowed_api_formats",
|
||||
"allowed_models",
|
||||
]
|
||||
|
||||
# 允许设置为 None 的字段(表示无限制)
|
||||
nullable_fields = ["quota_usd", "allowed_providers", "allowed_endpoints", "allowed_models"]
|
||||
nullable_fields = ["quota_usd", "allowed_providers", "allowed_api_formats", "allowed_models"]
|
||||
|
||||
for field, value in kwargs.items():
|
||||
if field not in updatable_fields:
|
||||
|
||||
Reference in New Issue
Block a user