mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
- 在 UsageService 新增 get_active_requests 方法,统一处理活跃请求查询 - 支持自动清理超时的 pending 请求(默认 5 分钟) - admin 和 user 接口均复用该方法,减少重复代码 - 支持按 ID 列表查询或查询所有活跃请求
1086 lines
42 KiB
Python
1086 lines
42 KiB
Python
"""用户个人 API 端点。"""
|
||
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||
from pydantic import ValidationError
|
||
from sqlalchemy import and_, func
|
||
from sqlalchemy.orm import Session, joinedload
|
||
|
||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.core.crypto import crypto_service
|
||
from src.core.exceptions import ForbiddenException, InvalidRequestException, NotFoundException, translate_pydantic_error
|
||
from src.core.logger import logger
|
||
from src.database import get_db
|
||
from src.models.api import (
|
||
ChangePasswordRequest,
|
||
CreateMyApiKeyRequest,
|
||
UpdateApiKeyProvidersRequest,
|
||
UpdatePreferencesRequest,
|
||
UpdateProfileRequest,
|
||
)
|
||
from src.models.database import ApiKey, Provider, Usage, User
|
||
from src.services.usage.service import UsageService
|
||
from src.services.user.apikey import ApiKeyService
|
||
from src.services.user.preference import PreferenceService
|
||
|
||
|
||
# Every route below is mounted under /api/users/me and tagged "User Profile".
router = APIRouter(
    prefix="/api/users/me",
    tags=["User Profile"],
)

# A single pipeline instance shared by all routes; each route hands it an adapter.
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
@router.get("")
async def get_my_profile(request: Request, db: Session = Depends(get_db)):
    """Return the current user's full profile, including preferences."""
    handler = MeProfileAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.put("")
async def update_my_profile(request: Request, db: Session = Depends(get_db)):
    """Update the current user's email and/or username."""
    handler = UpdateProfileAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.patch("/password")
async def change_my_password(request: Request, db: Session = Depends(get_db)):
    """Change the current user's password (the old password must be supplied)."""
    handler = ChangePasswordAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
# ============== API密钥管理 ==============
|
||
|
||
|
||
@router.get("/api-keys")
async def list_my_api_keys(request: Request, db: Session = Depends(get_db)):
    """List the current user's API keys together with their usage statistics."""
    handler = ListMyApiKeysAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.post("/api-keys")
async def create_my_api_key(request: Request, db: Session = Depends(get_db)):
    """Create a new API key for the current user."""
    handler = CreateMyApiKeyAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.get("/api-keys/{key_id}")
async def get_my_api_key(
    key_id: str,
    request: Request,
    include_key: bool = Query(False, description="Include full decrypted key in response"),
    db: Session = Depends(get_db),
):
    """Return one of the caller's API keys.

    With ``include_key=true`` the response carries only the decrypted key;
    otherwise it carries the key's metadata.
    """
    adapter_cls = GetMyFullKeyAdapter if include_key else GetMyApiKeyDetailAdapter
    handler = adapter_cls(key_id=key_id)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.delete("/api-keys/{key_id}")
async def delete_my_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
    """Delete one of the current user's API keys."""
    handler = DeleteMyApiKeyAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.patch("/api-keys/{key_id}")
async def toggle_my_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
    """Flip the active/inactive status of one of the current user's API keys."""
    handler = ToggleMyApiKeyAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
# ============== 使用统计 ==============
|
||
|
||
|
||
@router.get("/usage")
async def get_my_usage(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    db: Session = Depends(get_db),
):
    """Return the current user's usage dashboard, optionally limited to a date range."""
    handler = GetUsageAdapter(start_date=start_date, end_date=end_date)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.get("/usage/active")
async def get_my_active_requests(
    request: Request,
    ids: Optional[str] = Query(None, description="Comma-separated request IDs to query"),
    db: Session = Depends(get_db),
):
    """Return the status of the user's active requests (intended for polling)."""
    handler = GetActiveRequestsAdapter(ids=ids)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.get("/providers")
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
    """List active providers with their endpoints and available models."""
    handler = ListAvailableProvidersAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.get("/endpoint-status")
async def get_endpoint_status(request: Request, db: Session = Depends(get_db)):
    """Return endpoint status (simplified view without sensitive information)."""
    handler = GetEndpointStatusAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
# ============== API密钥与提供商关联 ==============
|
||
|
||
|
||
# UpdateApiKeyProvidersRequest 已移至 src/models/api.py
|
||
|
||
|
||
@router.put("/api-keys/{api_key_id}/providers")
async def update_api_key_providers(
    api_key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Replace the list of providers an API key is allowed to use."""
    handler = UpdateApiKeyProvidersAdapter(api_key_id=api_key_id)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.put("/api-keys/{api_key_id}/capabilities")
async def update_api_key_capabilities(
    api_key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update an API key's forced-capability configuration."""
    handler = UpdateApiKeyCapabilitiesAdapter(api_key_id=api_key_id)
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
# ============== 偏好设置 ==============
|
||
|
||
|
||
@router.get("/preferences")
async def get_my_preferences(request: Request, db: Session = Depends(get_db)):
    """Return the current user's preference settings."""
    handler = GetPreferencesAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.put("/preferences")
async def update_my_preferences(request: Request, db: Session = Depends(get_db)):
    """Update the current user's preference settings."""
    handler = UpdatePreferencesAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.get("/model-capabilities")
async def get_model_capability_settings(request: Request, db: Session = Depends(get_db)):
    """Return the current user's model capability settings."""
    handler = GetModelCapabilitySettingsAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
@router.put("/model-capabilities")
async def update_model_capability_settings(request: Request, db: Session = Depends(get_db)):
    """Update the current user's model capability settings."""
    handler = UpdateModelCapabilitySettingsAdapter()
    return await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
|
||
|
||
|
||
# ============== Pipeline 适配器 ==============
|
||
|
||
|
||
class MeProfileAdapter(AuthenticatedApiAdapter):
    """Return the caller's user data together with their preferences.

    The payload is whatever ``PreferenceService.get_user_with_preferences``
    builds for the authenticated user's id.
    """

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        current_user_id = context.user.id
        return PreferenceService.get_user_with_preferences(session, current_user_id)
|
||
|
||
|
||
class UpdateProfileAdapter(AuthenticatedApiAdapter):
    """Update the caller's email and/or username.

    Only fields that are truthy in the validated request are applied. Each one is
    first checked for uniqueness against all other users.

    Raises:
        InvalidRequestException: the payload fails validation, or the requested
            email/username is already used by another user.
        NotFoundException: the caller's user row no longer exists.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        try:
            request = UpdateProfileRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Re-load the user through this session. context.user may come from a cache
        # as a detached instance; mutating that object and committing would persist
        # nothing, and db.refresh() on it would raise.
        user = db.query(User).filter(User.id == context.user.id).first()
        if not user:
            raise NotFoundException("用户不存在")

        if request.email:
            existing = (
                db.query(User).filter(User.email == request.email, User.id != user.id).first()
            )
            if existing:
                raise InvalidRequestException("邮箱已被使用")
            user.email = request.email

        if request.username:
            existing = (
                db.query(User)
                .filter(User.username == request.username, User.id != user.id)
                .first()
            )
            if existing:
                raise InvalidRequestException("用户名已被使用")
            user.username = request.username

        user.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(user)
        # NOTE(review): if context.user is served from a user cache, that cached
        # copy may need invalidating after this change — confirm with the cache layer.
        return {"message": "个人信息更新成功"}
|
||
|
||
|
||
class ChangePasswordAdapter(AuthenticatedApiAdapter):
    """Change the caller's password after verifying the current one.

    Raises:
        InvalidRequestException: the payload fails validation, the old password
            does not match, or the new password is shorter than 6 characters.
        NotFoundException: the caller's user row no longer exists.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        try:
            request = ChangePasswordRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Re-load the user through this session. context.user may be a cached,
        # detached instance, and committing changes made to it would not persist
        # the new password.
        user = db.query(User).filter(User.id == context.user.id).first()
        if not user:
            raise NotFoundException("用户不存在")

        if not user.verify_password(request.old_password):
            raise InvalidRequestException("旧密码错误")
        if len(request.new_password) < 6:
            raise InvalidRequestException("密码长度至少6位")

        user.set_password(request.new_password)
        user.updated_at = datetime.now(timezone.utc)
        db.commit()
        logger.info(f"用户修改密码: {user.email}")
        return {"message": "密码修改成功"}
|
||
|
||
|
||
class ListMyApiKeysAdapter(AuthenticatedApiAdapter):
    """List the caller's API keys, newest first, with aggregated usage statistics.

    Statistics (request count, summed cost, latest ``Usage.created_at``) for all
    keys are fetched with a single grouped query over ``Usage``. Keys without
    usage report 0 requests, 0.0 cost and no last-used time.
    """

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        owner = context.user

        keys = (
            session.query(ApiKey)
            .filter(ApiKey.user_id == owner.id)
            .order_by(ApiKey.created_at.desc())
            .all()
        )
        if not keys:
            return []

        grouped_rows = (
            session.query(
                Usage.api_key_id,
                func.count(Usage.id).label("requests"),
                func.sum(Usage.total_cost_usd).label("cost"),
                func.max(Usage.created_at).label("last_used"),
            )
            .filter(Usage.api_key_id.in_([k.id for k in keys]))
            .group_by(Usage.api_key_id)
            .all()
        )
        usage_by_key = {row.api_key_id: row for row in grouped_rows}

        listing = []
        for k in keys:
            row = usage_by_key.get(k.id)
            if row is None:
                request_count, cost_usd, last_used = 0, 0.0, None
            else:
                request_count = row.requests or 0
                cost_usd = float(row.cost or 0)
                last_used = row.last_used

            listing.append(
                {
                    "id": k.id,
                    "name": k.name,
                    "key_display": k.get_display_key(),
                    "is_active": k.is_active,
                    "last_used_at": last_used.isoformat() if last_used else None,
                    "created_at": k.created_at.isoformat(),
                    "total_requests": request_count,
                    "total_cost_usd": cost_usd,
                    "allowed_providers": k.allowed_providers,
                    "force_capabilities": k.force_capabilities,
                }
            )
        return listing
|
||
|
||
|
||
class CreateMyApiKeyAdapter(AuthenticatedApiAdapter):
    """Create an API key for the caller.

    The response includes the plaintext key returned by
    ``ApiKeyService.create_api_key`` together with its display form.

    Raises:
        InvalidRequestException: the payload fails validation or the service
            rejects the request with a ``ValueError``.
    """

    async def handle(self, context):  # type: ignore[override]
        body = context.ensure_json_body()
        try:
            create_req = CreateMyApiKeyRequest.model_validate(body)
        except ValidationError as exc:
            details = exc.errors()
            message = translate_pydantic_error(details[0]) if details else "请求数据验证失败"
            raise InvalidRequestException(message)

        try:
            new_key, plaintext = ApiKeyService.create_api_key(
                db=context.db,
                user_id=context.user.id,
                name=create_req.name,
            )
        except ValueError as exc:
            # Service-level validation errors are reported as invalid requests.
            raise InvalidRequestException(str(exc))

        return {
            "id": new_key.id,
            "name": new_key.name,
            "key": plaintext,
            "key_display": new_key.get_display_key(),
            "message": "API密钥创建成功",
        }
|
||
|
||
|
||
@dataclass
class GetMyFullKeyAdapter(AuthenticatedApiAdapter):
    """Return the decrypted plaintext of one of the caller's API keys.

    Raises:
        NotFoundException: no key with ``key_id`` belongs to the caller.
        HTTPException(400): the key has no encrypted copy stored.
        HTTPException(500): decryption failed.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        owner = context.user

        # Scope the lookup to the caller so one user cannot read another's key.
        record = (
            session.query(ApiKey)
            .filter(ApiKey.id == self.key_id, ApiKey.user_id == owner.id)
            .first()
        )
        if not record:
            raise NotFoundException("API密钥不存在", "api_key")

        encrypted = record.key_encrypted
        if not encrypted:
            raise HTTPException(status_code=400, detail="该密钥没有存储完整密钥信息")

        try:
            plaintext = crypto_service.decrypt(encrypted)
        except Exception as e:
            logger.error(f"解密API密钥失败: Key ID {self.key_id}, 错误: {e}")
            raise HTTPException(status_code=500, detail="解密密钥失败")

        # Viewing a full key is logged.
        logger.info(f"用户 {owner.email} 查看完整API密钥: Key ID {self.key_id}")
        return {"key": plaintext}
|
||
|
||
|
||
@dataclass
class GetMyApiKeyDetailAdapter(AuthenticatedApiAdapter):
    """Return metadata for one of the caller's API keys (never the full key).

    Raises:
        NotFoundException: no key with ``key_id`` belongs to the caller.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        owner = context.user

        record = (
            session.query(ApiKey)
            .filter(ApiKey.id == self.key_id, ApiKey.user_id == owner.id)
            .first()
        )
        if not record:
            raise NotFoundException("API密钥不存在", "api_key")

        def iso_or_none(moment):
            # Optional timestamps serialize to ISO-8601 or None.
            return moment.isoformat() if moment else None

        return {
            "id": record.id,
            "name": record.name,
            "key_display": record.get_display_key(),
            "is_active": record.is_active,
            "allowed_providers": record.allowed_providers,
            "force_capabilities": record.force_capabilities,
            "rate_limit": record.rate_limit,
            "last_used_at": iso_or_none(record.last_used_at),
            "expires_at": iso_or_none(record.expires_at),
            "created_at": record.created_at.isoformat(),
        }
|
||
|
||
|
||
@dataclass
class DeleteMyApiKeyAdapter(AuthenticatedApiAdapter):
    """Delete one of the caller's API keys.

    Raises:
        NotFoundException: no key with ``key_id`` belongs to the caller.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        owned_key = (
            session.query(ApiKey)
            .filter(ApiKey.id == self.key_id, ApiKey.user_id == context.user.id)
            .first()
        )
        if not owned_key:
            raise NotFoundException("API密钥不存在", "api_key")

        session.delete(owned_key)
        session.commit()
        return {"message": "API密钥已删除"}
|
||
|
||
|
||
@dataclass
class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
    """Flip ``is_active`` on one of the caller's API keys and report the new state.

    Raises:
        NotFoundException: no key with ``key_id`` belongs to the caller.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        owned_key = (
            session.query(ApiKey)
            .filter(ApiKey.id == self.key_id, ApiKey.user_id == context.user.id)
            .first()
        )
        if not owned_key:
            raise NotFoundException("API密钥不存在", "api_key")

        owned_key.is_active = not owned_key.is_active
        session.commit()
        session.refresh(owned_key)

        state_label = "启用" if owned_key.is_active else "禁用"
        return {
            "id": owned_key.id,
            "is_active": owned_key.is_active,
            "message": f"API密钥已{state_label}",
        }
|
||
|
||
|
||
@dataclass
class GetUsageAdapter(AuthenticatedApiAdapter):
    """Build the caller's usage dashboard.

    The response contains:

    * request/token/cost totals plus per-model and per-provider breakdowns, all
      derived from ``UsageService.get_usage_summary``;
    * the 100 most recent ``Usage`` records;
    * the average response time of status-200 requests, in seconds;
    * a 365-day activity heatmap from ``UsageService.get_daily_activity``.

    Summary rows whose provider is ``"unknown"``, ``"pending"`` or ``None`` are
    excluded from totals and breakdowns. They are not excluded from the record
    list. Admin users additionally receive ``total_actual_cost``, per-model
    ``actual_total_cost_usd`` and per-record ``actual_cost``/``rate_multiplier``.

    Attributes:
        start_date: Optional lower bound; the record and average-response-time
            queries apply it as ``created_at >= start_date``. It is also passed
            to the summary service.
        end_date: Optional upper bound; those queries apply it as
            ``created_at <= end_date``. It is also passed to the summary service.
    """

    start_date: Optional[datetime]
    end_date: Optional[datetime]

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = context.user
        # Evaluated once; it gates every admin-only field below.
        is_admin = user.role == "admin"

        summary_list = UsageService.get_usage_summary(
            db=db,
            user_id=user.id,
            start_date=self.start_date,
            end_date=self.end_date,
        )
        # Rows for unknown/pending providers never reached a provider.
        filtered_summary = [
            item
            for item in summary_list
            if item.get("provider") not in ("unknown", "pending", None)
        ]

        total_requests = sum(item["requests"] for item in filtered_summary)
        total_input_tokens = sum(item["input_tokens"] for item in filtered_summary)
        total_output_tokens = sum(item["output_tokens"] for item in filtered_summary)
        total_tokens = sum(item["total_tokens"] for item in filtered_summary)
        # Report a float 0.0 (not sum()'s int 0) when there is nothing to add.
        total_cost = (
            sum(item["total_cost_usd"] for item in filtered_summary) if filtered_summary else 0.0
        )
        total_actual_cost = 0.0
        if is_admin and filtered_summary:
            total_actual_cost = sum(
                item.get("actual_total_cost_usd", 0.0) for item in filtered_summary
            )

        usage_records = (
            self._apply_date_range(db.query(Usage).filter(Usage.user_id == user.id))
            .order_by(Usage.created_at.desc())
            .limit(100)
            .all()
        )

        avg_response_ms = (
            self._apply_date_range(
                db.query(func.avg(Usage.response_time_ms)).filter(
                    Usage.user_id == user.id,
                    Usage.status_code == 200,
                    Usage.response_time_ms.isnot(None),
                )
            ).scalar()
            or 0
        )
        # Stored in milliseconds; reported in seconds.
        avg_response_time = float(avg_response_ms) / 1000.0 if avg_response_ms else 0

        response_data = {
            "total_requests": total_requests,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "avg_response_time": avg_response_time,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "summary_by_model": self._summarize_by_model(filtered_summary, is_admin),
            "summary_by_provider": self._summarize_by_provider(filtered_summary),
            "records": [self._serialize_record(r, is_admin) for r in usage_records],
        }

        response_data["activity_heatmap"] = UsageService.get_daily_activity(
            db=db,
            user_id=user.id,
            window_days=365,
            include_actual_cost=is_admin,
        )

        if is_admin:
            response_data["total_actual_cost"] = total_actual_cost
            # Debug aid: trace the cost fields of the first few records.
            for i, (r, record) in enumerate(zip(usage_records[:3], response_data["records"])):
                logger.debug(
                    f"Usage record {i}: id={r.id}, actual_total_cost_usd={r.actual_total_cost_usd}, "
                    f"rate_multiplier={r.rate_multiplier}, returned: actual_cost={record['actual_cost']}, rate_mult={record['rate_multiplier']}"
                )

        return response_data

    def _apply_date_range(self, query):
        """Apply the optional inclusive ``Usage.created_at`` bounds to *query*."""
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        return query

    @staticmethod
    def _summarize_by_model(filtered_summary, include_actual_cost):
        """Aggregate summary rows per model, ordered by request count (descending)."""
        model_summary = {}
        for item in filtered_summary:
            model_name = item["model"]
            stats = model_summary.get(model_name)
            if stats is None:
                stats = {
                    "model": model_name,
                    "requests": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                    "total_cost_usd": 0.0,
                }
                if include_actual_cost:
                    stats["actual_total_cost_usd"] = 0.0
                model_summary[model_name] = stats

            stats["requests"] += item["requests"]
            stats["input_tokens"] += item["input_tokens"]
            stats["output_tokens"] += item["output_tokens"]
            stats["total_tokens"] += item["total_tokens"]
            stats["total_cost_usd"] += item["total_cost_usd"]
            if include_actual_cost:
                stats["actual_total_cost_usd"] += item.get("actual_total_cost_usd", 0.0)

        return sorted(model_summary.values(), key=lambda s: s["requests"], reverse=True)

    @staticmethod
    def _summarize_by_provider(filtered_summary):
        """Aggregate summary rows per provider (for the provider table).

        The average response time is weighted by each row's request count. Rows
        without a truthy ``avg_response_time_ms`` are left out of that average.
        The result is ordered by request count (descending).
        """
        provider_summary = {}
        for item in filtered_summary:
            provider_name = item["provider"]
            stats = provider_summary.get(provider_name)
            if stats is None:
                stats = provider_summary[provider_name] = {
                    "provider": provider_name,
                    "requests": 0,
                    "total_tokens": 0,
                    "total_cost_usd": 0.0,
                    "success_count": 0,
                    "total_response_time_ms": 0.0,
                    "response_time_count": 0,
                }
            stats["requests"] += item["requests"]
            stats["total_tokens"] += item["total_tokens"]
            stats["total_cost_usd"] += item["total_cost_usd"]
            # NOTE(review): summary rows carry no failure count, so every request
            # is assumed successful and success_rate is effectively always 100.
            stats["success_count"] += item["requests"]
            if item.get("avg_response_time_ms"):
                stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
                stats["response_time_count"] += item["requests"]

        summary_by_provider = []
        for stats in provider_summary.values():
            avg_response_time_ms = (
                stats["total_response_time_ms"] / stats["response_time_count"]
                if stats["response_time_count"] > 0
                else 0
            )
            success_rate = (
                (stats["success_count"] / stats["requests"] * 100)
                if stats["requests"] > 0
                else 100
            )
            summary_by_provider.append(
                {
                    "provider": stats["provider"],
                    "requests": stats["requests"],
                    "total_tokens": stats["total_tokens"],
                    "total_cost_usd": stats["total_cost_usd"],
                    "success_rate": round(success_rate, 2),
                    "avg_response_time_ms": round(avg_response_time_ms, 2),
                }
            )
        return sorted(summary_by_provider, key=lambda s: s["requests"], reverse=True)

    @staticmethod
    def _serialize_record(r, include_actual_cost):
        """Serialize one ``Usage`` row; admins also get actual cost and multiplier."""
        record = {
            "id": r.id,
            "provider": r.provider,
            "model": r.model,
            "target_model": r.target_model,  # target model name after mapping
            "api_format": r.api_format,
            "input_tokens": r.input_tokens,
            "output_tokens": r.output_tokens,
            "total_tokens": r.total_tokens,
            "cost": r.total_cost_usd,
            "response_time_ms": r.response_time_ms,
            "is_stream": r.is_stream,
            "status": r.status,  # request status: pending, streaming, completed, failed
            "created_at": r.created_at.isoformat(),
            "cache_creation_input_tokens": r.cache_creation_input_tokens,
            "cache_read_input_tokens": r.cache_read_input_tokens,
            "status_code": r.status_code,
            "error_message": r.error_message,
            "input_price_per_1m": r.input_price_per_1m,
            "output_price_per_1m": r.output_price_per_1m,
            "cache_creation_price_per_1m": r.cache_creation_price_per_1m,
            "cache_read_price_per_1m": r.cache_read_price_per_1m,
        }
        if include_actual_cost:
            # Always emit concrete values so the frontend never shows a "-" placeholder.
            record["actual_cost"] = (
                r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
            )
            record["rate_multiplier"] = (
                r.rate_multiplier if r.rate_multiplier is not None else 1.0
            )
        return record
|
||
|
||
|
||
@dataclass
class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
    """Lightweight status lookup for the caller's active requests (used for polling).

    ``ids`` is an optional comma-separated list of request IDs:

    * omitted or empty: ``UsageService.get_active_requests`` is called with
      ``ids=None``;
    * containing only blanks/commas: ``{"requests": []}`` is returned without
      querying;
    * otherwise: the stripped, non-empty IDs are passed through.
    """

    ids: Optional[str] = None

    async def handle(self, context):  # type: ignore[override]
        # Uses the module-level UsageService import; the previous function-local
        # re-import from a different module path was redundant.
        id_list = None
        if self.ids:
            stripped = (part.strip() for part in self.ids.split(","))
            id_list = [part for part in stripped if part]
            if not id_list:
                return {"requests": []}

        active = UsageService.get_active_requests(
            db=context.db, ids=id_list, user_id=context.user.id
        )
        return {"requests": active}
|
||
|
||
|
||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
    """List active providers with their endpoints and the models users can request.

    Each provider entry contains:

    * all of its endpoints;
    * one entry per provider model, named after its GlobalModel when linked and
      otherwise after ``provider_model_name``;
    * one extra entry per active ModelMapping alias whose target GlobalModel the
      provider implements. An alias applies when it is global (``provider_id``
      is NULL) or scoped to that provider. Alias entries reuse the
      pricing/capability fields of the first provider model implementing the
      target.
    """

    async def handle(self, context):  # type: ignore[override]
        from sqlalchemy.orm import selectinload

        from src.models.database import Model, ModelMapping

        db = context.db

        # Eager-load endpoints, models and each model's GlobalModel so the loops
        # below trigger no per-provider lazy loads.
        providers = (
            db.query(Provider)
            .options(
                selectinload(Provider.endpoints),
                selectinload(Provider.models).selectinload(Model.global_model),
            )
            .filter(Provider.is_active.is_(True))
            .all()
        )
        if not providers:
            return []

        # Fetch every candidate alias in ONE query, instead of one query per
        # provider. Per-provider scoping is applied in Python below.
        all_global_ids = {
            m.global_model_id for p in providers for m in p.models if m.global_model_id
        }
        alias_mappings = []
        if all_global_ids:
            alias_mappings = (
                db.query(ModelMapping)
                .options(joinedload(ModelMapping.target_global_model))
                .filter(
                    ModelMapping.target_global_model_id.in_(all_global_ids),
                    ModelMapping.is_active.is_(True),
                    ModelMapping.provider_id.in_([p.id for p in providers])
                    | ModelMapping.provider_id.is_(None),
                )
                .all()
            )

        result = []
        for provider in providers:
            endpoints_data = [
                {
                    "id": ep.id,
                    "api_format": ep.api_format if ep.api_format else None,
                    "base_url": ep.base_url,
                    "is_active": ep.is_active,
                }
                for ep in provider.endpoints
            ]

            models_data = []
            # First provider model per GlobalModel id; aliases resolve to it.
            model_by_global_id = {}
            for model in provider.models:
                global_model = model.global_model
                if global_model:
                    unified_name = global_model.name
                    display_name = global_model.display_name
                else:
                    unified_name = display_name = model.provider_model_name
                models_data.append(self._model_entry(model.id, unified_name, display_name, model))
                if model.global_model_id:
                    model_by_global_id.setdefault(model.global_model_id, model)

            for alias_obj in alias_mappings:
                # Skip aliases scoped to a different provider.
                if alias_obj.provider_id is not None and alias_obj.provider_id != provider.id:
                    continue
                model = model_by_global_id.get(alias_obj.target_global_model_id)
                if model is None:
                    continue  # this provider does not implement the alias target
                target = alias_obj.target_global_model
                models_data.append(
                    self._model_entry(
                        alias_obj.id,
                        alias_obj.source_model,
                        target.display_name if target else alias_obj.source_model,
                        model,
                    )
                )

            result.append(
                {
                    "id": provider.id,
                    "name": provider.name,
                    "display_name": provider.display_name,
                    "description": provider.description,
                    "provider_priority": provider.provider_priority,
                    "endpoints": endpoints_data,
                    "models": models_data,
                }
            )
        return result

    @staticmethod
    def _model_entry(entry_id, name, display_name, model):
        """Serialize one requestable model; pricing/capabilities come from *model*."""
        return {
            "id": entry_id,
            "name": name,
            "display_name": display_name,
            "input_price_per_1m": model.input_price_per_1m,
            "output_price_per_1m": model.output_price_per_1m,
            "cache_creation_price_per_1m": model.cache_creation_price_per_1m,
            "cache_read_price_per_1m": model.cache_read_price_per_1m,
            "supports_vision": model.supports_vision,
            "supports_function_calling": model.supports_function_calling,
            "supports_streaming": model.supports_streaming,
        }
|
||
|
||
|
||
@dataclass
class UpdateApiKeyProvidersAdapter(AuthenticatedApiAdapter):
    """Replace the set of providers one of the caller's API keys may use.

    Only the provider IDs from the request's ``allowed_providers`` entries are
    stored. An empty or null list clears the restriction (stored as ``None``).

    Raises:
        InvalidRequestException: the payload fails validation, or any provider
            ID is unknown or inactive.
        NotFoundException: no key with ``api_key_id`` belongs to the caller.
    """

    api_key_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = context.user
        payload = context.ensure_json_body()
        try:
            request = UpdateApiKeyProvidersRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        api_key = (
            db.query(ApiKey).filter(ApiKey.id == self.api_key_id, ApiKey.user_id == user.id).first()
        )
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")

        # allowed_providers is stored as a list of provider ID strings, not as
        # full ProviderConfig objects.
        provider_ids = [cfg.provider_id for cfg in request.allowed_providers or []]
        if provider_ids:
            valid_ids = {
                row.id
                for row in db.query(Provider.id)
                .filter(Provider.id.in_(provider_ids), Provider.is_active.is_(True))
                .all()
            }
            invalid = set(provider_ids) - valid_ids
            if invalid:
                # Sort (and stringify) so the message is deterministic; set order is not.
                invalid_list = ", ".join(sorted(str(pid) for pid in invalid))
                raise InvalidRequestException(f"无效的提供商ID: {invalid_list}")

        api_key.allowed_providers = provider_ids or None
        api_key.updated_at = datetime.now(timezone.utc)
        db.commit()
        logger.debug(f"用户 {user.id} 更新API密钥 {self.api_key_id} 的可用提供商")
        return {"message": "API密钥可用提供商已更新"}
|
||
|
||
|
||
@dataclass
class UpdateApiKeyCapabilitiesAdapter(AuthenticatedApiAdapter):
    """Update the forced-capability configuration of one of the caller's API keys.

    ``force_capabilities`` must be an object mapping capability names to
    booleans. Every name must exist in ``CAPABILITY_DEFINITIONS`` and be
    ``USER_CONFIGURABLE``. A missing or null value is stored as ``None``. The
    change is recorded via ``audit_service.log_event`` with old and new values.

    Raises:
        InvalidRequestException: the body is not a JSON object, or the
            capability map is malformed or names a non-configurable capability.
        NotFoundException: no key with ``api_key_id`` belongs to the caller.
    """

    api_key_id: str

    async def handle(self, context):  # type: ignore[override]
        from src.core.key_capabilities import CAPABILITY_DEFINITIONS, CapabilityConfigMode
        from src.models.database import AuditEventType
        from src.services.system.audit import audit_service

        db = context.db
        user = context.user
        payload = context.ensure_json_body()
        # payload.get() below requires an object; reject other JSON values
        # (e.g. arrays) as invalid input instead of failing with AttributeError.
        if not isinstance(payload, dict):
            raise InvalidRequestException("请求数据验证失败")

        api_key = (
            db.query(ApiKey).filter(ApiKey.id == self.api_key_id, ApiKey.user_id == user.id).first()
        )
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")

        # Keep the previous value for the audit record.
        old_capabilities = api_key.force_capabilities

        force_capabilities = payload.get("force_capabilities")
        if force_capabilities is not None:
            if not isinstance(force_capabilities, dict):
                raise InvalidRequestException("force_capabilities 必须是对象类型")
            # Only capabilities explicitly marked user-configurable may be set here.
            for cap_name, cap_value in force_capabilities.items():
                cap_def = CAPABILITY_DEFINITIONS.get(cap_name)
                if not cap_def:
                    raise InvalidRequestException(f"未知的能力类型: {cap_name}")
                if cap_def.config_mode != CapabilityConfigMode.USER_CONFIGURABLE:
                    raise InvalidRequestException(f"能力 {cap_name} 不支持用户配置")
                if not isinstance(cap_value, bool):
                    raise InvalidRequestException(f"能力 {cap_name} 的值必须是布尔类型")

        api_key.force_capabilities = force_capabilities
        api_key.updated_at = datetime.now(timezone.utc)
        db.commit()

        audit_service.log_event(
            db=db,
            event_type=AuditEventType.CONFIG_CHANGED,
            description="用户更新 API Key 能力配置",
            user_id=user.id,
            api_key_id=api_key.id,
            metadata={
                "action": "update_api_key_capabilities",
                "old_capabilities": old_capabilities,
                "new_capabilities": force_capabilities,
            },
        )

        logger.debug(f"用户 {user.id} 更新API密钥 {self.api_key_id} 的强制能力配置: {force_capabilities}")
        return {
            "message": "API密钥能力配置已更新",
            "force_capabilities": api_key.force_capabilities,
        }
|
||
|
||
|
||
class GetPreferencesAdapter(AuthenticatedApiAdapter):
    """Return the caller's preference settings.

    The preferences row is obtained via
    ``PreferenceService.get_or_create_preferences``. The default provider is
    reported both by id and by name (``None`` when unset).
    """

    async def handle(self, context):  # type: ignore[override]
        prefs = PreferenceService.get_or_create_preferences(context.db, context.user.id)

        provider = prefs.default_provider
        notifications = {
            "email": prefs.email_notifications,
            "usage_alerts": prefs.usage_alerts,
            "announcements": prefs.announcement_notifications,
        }
        return {
            "avatar_url": prefs.avatar_url,
            "bio": prefs.bio,
            "default_provider_id": prefs.default_provider_id,
            "default_provider": provider.name if provider else None,
            "theme": prefs.theme,
            "language": prefs.language,
            "timezone": prefs.timezone,
            "notifications": notifications,
        }
|
||
|
||
|
||
class UpdatePreferencesAdapter(AuthenticatedApiAdapter):
    """Validate the request body and pass every preference field to the service.

    Raises:
        InvalidRequestException: the payload fails validation.
    """

    # Request attributes forwarded verbatim as keyword arguments.
    _FORWARDED_FIELDS = (
        "avatar_url",
        "bio",
        "default_provider_id",
        "theme",
        "language",
        "timezone",
        "email_notifications",
        "usage_alerts",
        "announcement_notifications",
    )

    async def handle(self, context):  # type: ignore[override]
        body = context.ensure_json_body()
        try:
            prefs_req = UpdatePreferencesRequest.model_validate(body)
        except ValidationError as exc:
            issues = exc.errors()
            message = translate_pydantic_error(issues[0]) if issues else "请求数据验证失败"
            raise InvalidRequestException(message)

        updates = {field: getattr(prefs_req, field) for field in self._FORWARDED_FIELDS}
        PreferenceService.update_preferences(
            db=context.db,
            user_id=context.user.id,
            **updates,
        )
        return {"message": "偏好设置更新成功"}
|
||
|
||
|
||
class GetModelCapabilitySettingsAdapter(AuthenticatedApiAdapter):
    """Return the authenticated user's per-model capability settings."""

    async def handle(self, context):  # type: ignore[override]
        """Return the stored settings, substituting an empty dict when unset/falsy."""
        stored = context.user.model_capability_settings
        if not stored:
            stored = {}
        return {"model_capability_settings": stored}
|
||
|
||
|
||
class UpdateModelCapabilitySettingsAdapter(AuthenticatedApiAdapter):
    """Update the authenticated user's per-model capability settings.

    Expected body shape: ``{"model_capability_settings": {<model>: {<cap>: bool}}}``.
    When ``model_capability_settings`` is missing or null, validation is
    skipped and the user's stored settings are set to None.
    """

    @staticmethod
    def _validate_settings(settings) -> None:
        """Check that *settings* only sets user-configurable boolean capabilities.

        ``None`` is accepted without further checks.

        Raises:
            InvalidRequestException: if *settings* is not a dict, a model entry
                is not a dict, a capability is unknown or not user-configurable,
                or a capability value is not a bool.
        """
        from src.core.key_capabilities import CAPABILITY_DEFINITIONS, CapabilityConfigMode

        if settings is None:
            return
        if not isinstance(settings, dict):
            raise InvalidRequestException("model_capability_settings 必须是对象类型")

        for model_name, capabilities in settings.items():
            if not isinstance(model_name, str):
                raise InvalidRequestException("模型名称必须是字符串")
            if not isinstance(capabilities, dict):
                raise InvalidRequestException(f"模型 {model_name} 的能力配置必须是对象类型")

            # Only capabilities declared as user-configurable may be set here.
            for cap_name, cap_value in capabilities.items():
                cap_def = CAPABILITY_DEFINITIONS.get(cap_name)
                if not cap_def:
                    raise InvalidRequestException(f"未知的能力类型: {cap_name}")
                if cap_def.config_mode != CapabilityConfigMode.USER_CONFIGURABLE:
                    raise InvalidRequestException(f"能力 {cap_name} 不支持用户配置")
                if not isinstance(cap_value, bool):
                    raise InvalidRequestException(f"能力 {cap_name} 的值必须是布尔类型")

    async def handle(self, context):  # type: ignore[override]
        """Validate, persist, invalidate the user cache, and audit the change.

        Raises:
            NotFoundException: if the user no longer exists in the database.
            InvalidRequestException: if the submitted settings are invalid.
        """
        from src.models.database import AuditEventType
        from src.services.cache.user_cache import UserCacheService
        from src.services.system.audit import audit_service

        db = context.db
        # Re-query the user so the instance is attached to this session
        # (context.user may come from the cache as a detached object).
        user = db.query(User).filter(User.id == context.user.id).first()
        if not user:
            raise NotFoundException("用户不存在")
        payload = context.ensure_json_body()

        # Keep the previous value for the audit record.
        old_settings = user.model_capability_settings

        settings = payload.get("model_capability_settings")
        self._validate_settings(settings)

        user.model_capability_settings = settings
        user.updated_at = datetime.now(timezone.utc)
        db.commit()

        # Drop the cached user so subsequent reads see the new settings.
        await UserCacheService.invalidate_user_cache(user.id, user.email)

        audit_service.log_event(
            db=db,
            event_type=AuditEventType.CONFIG_CHANGED,
            description="用户更新模型能力配置",
            user_id=user.id,
            metadata={
                "action": "update_model_capability_settings",
                "old_settings": old_settings,
                "new_settings": settings,
            },
        )

        logger.debug(f"用户 {user.id} 更新模型能力配置: {settings}")
        return {
            "message": "模型能力配置已更新",
            "model_capability_settings": user.model_capability_settings,
        }
|
||
|
||
|
||
class GetEndpointStatusAdapter(AuthenticatedApiAdapter):
    """Return endpoint health status (simplified view without sensitive admin fields).

    The response is cached for ``_cache_ttl`` seconds under a single shared key
    (``endpoint_status:all``). Cache problems are treated as best-effort: they
    are logged and never prevent the status from being computed and returned.
    """

    # Class-level cache backend shared by all instances; created lazily on first use.
    _cache_backend = None
    _cache_ttl = 60  # seconds

    @classmethod
    async def _get_cache(cls):
        """Return the shared cache backend, creating it on the first call."""
        if cls._cache_backend is None:
            from src.services.cache.backend import get_cache_backend

            cls._cache_backend = await get_cache_backend(
                name="endpoint_status",
                backend_type="auto",
                ttl=cls._cache_ttl,  # this factory takes `ttl`, not `default_ttl`
            )
        return cls._cache_backend

    async def handle(self, context):  # type: ignore[override]
        """Serve the endpoint health summary, preferring a cached copy."""
        from src.services.health.endpoint import EndpointHealthService

        db = context.db
        cache_key = "endpoint_status:all"

        # Obtaining the backend is inside the try as well: a failure to create
        # the cache backend must degrade to a direct query, not fail the request.
        cache = None
        try:
            cache = await self._get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                return cached
        except Exception as exc:
            logger.debug(f"端点状态缓存读取失败, 回退到实时查询: {exc}")

        # Shared health service, regular-user view (no admin-only fields).
        result = EndpointHealthService.get_endpoint_health_by_format(
            db=db,
            lookback_hours=6,
            include_admin_fields=False,
            use_cache=True,
        )

        # Only attempt to populate the cache when a backend was obtained.
        if cache is not None:
            try:
                await cache.set(cache_key, result, ttl=self._cache_ttl)
            except Exception as exc:
                logger.debug(f"端点状态缓存写入失败: {exc}")

        return result
|