Files
Aether/src/services/user/apikey.py
fawney19 9dad194130 fix: 修复 API Key 访问限制字段无法清除的问题
- 统一前端创建和更新 API Key 时的空数组处理逻辑
- 后端创建和更新接口都支持空数组转 NULL(表示不限制)
- 开启自动刷新时立即刷新一次数据
2025-12-24 22:35:30 +08:00

413 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API密钥管理服务
"""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.models.database import ApiKey, Usage, User
class ApiKeyService:
"""API密钥管理服务"""
@staticmethod
def create_api_key(
db: Session,
user_id: str, # UUID
name: Optional[str] = None,
allowed_providers: Optional[List[str]] = None,
allowed_api_formats: Optional[List[str]] = None,
allowed_models: Optional[List[str]] = None,
rate_limit: int = 100,
concurrent_limit: int = 5,
expire_days: Optional[int] = None,
initial_balance_usd: Optional[float] = None,
is_standalone: bool = False,
auto_delete_on_expiry: bool = False,
) -> tuple[ApiKey, str]:
"""创建新的API密钥返回密钥对象和明文密钥
Args:
db: 数据库会话
user_id: 用户ID
name: 密钥名称
allowed_providers: 允许的提供商列表
allowed_api_formats: 允许的 API 格式列表
allowed_models: 允许的模型列表
rate_limit: 速率限制
concurrent_limit: 并发限制
expire_days: 过期天数None = 永不过期
initial_balance_usd: 初始余额USD仅用于独立KeyNone = 无限制
is_standalone: 是否为独立余额Key仅管理员可创建
auto_delete_on_expiry: 过期后是否自动删除True=物理删除False=仅禁用)
"""
# 生成密钥
key = ApiKey.generate_key()
key_hash = ApiKey.hash_key(key)
key_encrypted = crypto_service.encrypt(key) # 加密存储密钥
# 计算过期时间
expires_at = None
if expire_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# 空数组转为 None表示不限制
api_key = ApiKey(
user_id=user_id,
key_hash=key_hash,
key_encrypted=key_encrypted,
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
allowed_providers=allowed_providers or None,
allowed_api_formats=allowed_api_formats or None,
allowed_models=allowed_models or None,
rate_limit=rate_limit,
concurrent_limit=concurrent_limit,
expires_at=expires_at,
balance_used_usd=0.0,
current_balance_usd=initial_balance_usd, # 直接使用初始余额None = 无限制
is_standalone=is_standalone,
auto_delete_on_expiry=auto_delete_on_expiry,
is_active=True,
)
db.add(api_key)
db.commit()
db.refresh(api_key)
logger.info(f"创建API密钥: 用户ID {user_id}, 密钥名 {api_key.name}, "
f"独立Key={is_standalone}, 初始余额={initial_balance_usd}")
return api_key, key # 返回密钥对象和明文密钥
@staticmethod
def get_api_key(db: Session, key_id: str) -> Optional[ApiKey]: # UUID
"""获取API密钥"""
return db.query(ApiKey).filter(ApiKey.id == key_id).first()
@staticmethod
def get_api_key_by_key(db: Session, key: str) -> Optional[ApiKey]:
"""通过密钥字符串获取API密钥"""
key_hash = ApiKey.hash_key(key)
return db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
@staticmethod
def list_user_api_keys(
db: Session, user_id: str, is_active: Optional[bool] = None # UUID
) -> List[ApiKey]:
"""列出用户的所有API密钥不包括独立Key"""
query = db.query(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.is_standalone == False # 排除独立Key
)
if is_active is not None:
query = query.filter(ApiKey.is_active == is_active)
return query.order_by(ApiKey.created_at.desc()).all()
@staticmethod
def list_standalone_api_keys(db: Session, is_active: Optional[bool] = None) -> List[ApiKey]:
"""列出所有独立余额Key仅管理员可用"""
query = db.query(ApiKey).filter(ApiKey.is_standalone == True)
if is_active is not None:
query = query.filter(ApiKey.is_active == is_active)
return query.order_by(ApiKey.created_at.desc()).all()
@staticmethod
def update_api_key(db: Session, key_id: str, **kwargs) -> Optional[ApiKey]: # UUID
"""更新API密钥"""
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
return None
# 可更新的字段
updatable_fields = [
"name",
"allowed_providers",
"allowed_api_formats",
"allowed_models",
"rate_limit",
"concurrent_limit",
"is_active",
"expires_at",
"balance_limit_usd",
"auto_delete_on_expiry",
]
# 允许显式设置为空数组/None 的字段(空数组会转为 None表示"全部"
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
for field, value in kwargs.items():
if field not in updatable_fields:
continue
# 对于 nullable_list_fields空数组应该转为 None表示不限制
if field in nullable_list_fields:
if value is not None:
# 空数组转为 None表示允许全部
setattr(api_key, field, value if value else None)
elif value is not None:
setattr(api_key, field, value)
api_key.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(api_key)
logger.debug(f"更新API密钥: ID {key_id}")
return api_key
@staticmethod
def delete_api_key(db: Session, key_id: str) -> bool: # UUID
"""删除API密钥禁用"""
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
return False
api_key.is_active = False
api_key.updated_at = datetime.now(timezone.utc)
db.commit()
logger.info(f"删除API密钥: ID {key_id}")
return True
@staticmethod
def get_remaining_balance(api_key: ApiKey) -> Optional[float]:
"""计算剩余余额仅用于独立Key
Returns:
剩余余额None 表示无限制或非独立Key
"""
if not api_key.is_standalone:
return None
if api_key.current_balance_usd is None:
return None
# 剩余余额 = 当前余额 - 已使用余额
remaining = api_key.current_balance_usd - (api_key.balance_used_usd or 0)
return max(0, remaining) # 不能为负数
@staticmethod
def check_balance(api_key: ApiKey) -> tuple[bool, Optional[float]]:
"""检查余额限制仅用于独立Key
Returns:
(is_allowed, remaining_balance): 是否允许请求剩余余额None表示无限制
"""
if not api_key.is_standalone:
# 非独立Key不检查余额
return True, None
# 使用新的预付费模式: current_balance_usd
if api_key.current_balance_usd is None:
# 无余额限制
return True, None
# 使用统一的余额计算方法
remaining = ApiKeyService.get_remaining_balance(api_key)
is_allowed = remaining > 0 if remaining is not None else True
if not is_allowed:
logger.warning(f"API密钥余额不足: Key ID {api_key.id}, " f"剩余余额 ${remaining:.4f}")
return is_allowed, remaining
@staticmethod
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
"""检查速率限制
Returns:
(is_allowed, remaining): 是否允许请求,剩余可用次数
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
"""
# 如果 rate_limit 为 None表示不限制
if api_key.rate_limit is None:
return True, -1 # -1 表示无限制
# 计算时间窗口
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
# 统计窗口内的请求数
request_count = (
db.query(func.count(Usage.id))
.filter(Usage.api_key_id == api_key.id, Usage.created_at >= window_start)
.scalar()
or 0
)
# 检查是否超限
is_allowed = request_count < api_key.rate_limit
if not is_allowed:
logger.warning(f"API密钥速率限制: Key ID {api_key.id}, 请求数 {request_count}/{api_key.rate_limit}")
return is_allowed, api_key.rate_limit - request_count
@staticmethod
def add_balance(db: Session, key_id: str, amount_usd: float) -> Optional[ApiKey]:
"""为独立余额Key调整余额
Args:
db: 数据库会话
key_id: API Key ID
amount_usd: 要调整的余额金额USD正数为增加负数为扣除
Returns:
更新后的API Key对象如果Key不存在或不是独立Key则返回None
"""
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
logger.warning(f"余额调整失败: Key ID {key_id} 不存在")
return None
if not api_key.is_standalone:
logger.warning(f"余额调整失败: Key ID {key_id} 不是独立余额Key")
return None
if amount_usd == 0:
logger.warning(f"余额调整失败: 调整金额不能为0当前值 ${amount_usd}")
return None
# 如果是扣除(负数),检查是否超过当前余额
if amount_usd < 0:
current = api_key.current_balance_usd or 0
if abs(amount_usd) > current:
logger.warning(f"余额扣除失败: 扣除金额 ${abs(amount_usd):.4f} 超过当前余额 ${current:.4f}")
return None
# 调整当前余额
if api_key.current_balance_usd is None:
api_key.current_balance_usd = amount_usd if amount_usd > 0 else 0
else:
api_key.current_balance_usd = max(0, api_key.current_balance_usd + amount_usd)
api_key.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(api_key)
action = "增加" if amount_usd > 0 else "扣除"
logger.info(f"余额调整成功: Key ID {key_id}, {action} ${abs(amount_usd):.4f}, "
f"新余额 ${api_key.current_balance_usd:.4f}")
return api_key
@staticmethod
def cleanup_expired_keys(db: Session, auto_delete: bool = False) -> int:
"""清理过期的API密钥
Args:
db: 数据库会话
auto_delete: 全局默认行为True=物理删除False=仅禁用)
单个Key的 auto_delete_on_expiry 字段会覆盖此设置
Returns:
int: 清理的密钥数量
"""
now = datetime.now(timezone.utc)
expired_keys = (
db.query(ApiKey)
.filter(ApiKey.expires_at <= now, ApiKey.is_active == True) # 只处理仍然活跃的
.all()
)
count = 0
for api_key in expired_keys:
# 优先使用Key自身的auto_delete_on_expiry设置,否则使用全局设置
should_delete = (
api_key.auto_delete_on_expiry
if api_key.auto_delete_on_expiry is not None
else auto_delete
)
if should_delete:
# 物理删除Usage记录会保留因为是 SET NULL
db.delete(api_key)
logger.info(f"删除过期API密钥: ID {api_key.id}, 名称 {api_key.name}, "
f"过期时间 {api_key.expires_at}")
else:
# 仅禁用
api_key.is_active = False
api_key.updated_at = now
logger.info(f"禁用过期API密钥: ID {api_key.id}, 名称 {api_key.name}, "
f"过期时间 {api_key.expires_at}")
count += 1
if count > 0:
db.commit()
return count
@staticmethod
def get_api_key_stats(
db: Session,
key_id: str, # UUID
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> Dict[str, Any]:
"""获取API密钥使用统计"""
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
return {}
query = db.query(Usage).filter(Usage.api_key_id == key_id)
if start_date:
query = query.filter(Usage.created_at >= start_date)
if end_date:
query = query.filter(Usage.created_at <= end_date)
# 统计数据
stats = db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost_usd"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
).filter(Usage.api_key_id == key_id)
if start_date:
stats = stats.filter(Usage.created_at >= start_date)
if end_date:
stats = stats.filter(Usage.created_at <= end_date)
result = stats.first()
# 按天统计
daily_stats = db.query(
func.date(Usage.created_at).label("date"),
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost_usd"),
).filter(Usage.api_key_id == key_id)
if start_date:
daily_stats = daily_stats.filter(Usage.created_at >= start_date)
if end_date:
daily_stats = daily_stats.filter(Usage.created_at <= end_date)
daily_stats = daily_stats.group_by(func.date(Usage.created_at)).all()
return {
"key_id": key_id,
"key_name": api_key.name,
"total_requests": result.requests or 0,
"total_tokens": result.tokens or 0,
"total_cost_usd": float(result.cost_usd or 0),
"avg_response_time_ms": float(result.avg_response_time or 0),
"daily_stats": [
{
"date": stat.date.isoformat() if stat.date else None,
"requests": stat.requests,
"tokens": stat.tokens,
"cost_usd": float(stat.cost_usd),
}
for stat in daily_stats
],
}