fix: 统一时区处理,确保所有 datetime 带时区信息

- token_bucket.py: get_reset_time 和 Redis 后端使用 timezone.utc
- sliding_window.py: get_reset_time 和 retry_after 计算使用 timezone.utc
- provider_strategy.py: dateutil.parser 解析后确保有时区信息
This commit is contained in:
fawney19
2026-01-05 02:23:24 +08:00
parent 523e27ba9a
commit dec681fea0
3 changed files with 19 additions and 12 deletions

View File

@@ -2,7 +2,7 @@
提供商策略管理 API 端点 提供商策略管理 API 端点
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
@@ -103,6 +103,9 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
if config.quota_last_reset_at: if config.quota_last_reset_at:
new_reset_at = parser.parse(config.quota_last_reset_at) new_reset_at = parser.parse(config.quota_last_reset_at)
# 确保有时区信息,如果没有则假设为 UTC
if new_reset_at.tzinfo is None:
new_reset_at = new_reset_at.replace(tzinfo=timezone.utc)
provider.quota_last_reset_at = new_reset_at provider.quota_last_reset_at = new_reset_at
# 自动同步该周期内的历史使用量 # 自动同步该周期内的历史使用量
@@ -118,7 +121,11 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}") logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
if config.quota_expires_at: if config.quota_expires_at:
provider.quota_expires_at = parser.parse(config.quota_expires_at) expires_at = parser.parse(config.quota_expires_at)
# 确保有时区信息,如果没有则假设为 UTC
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
provider.quota_expires_at = expires_at
db.commit() db.commit()
db.refresh(provider) db.refresh(provider)
@@ -149,7 +156,7 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
if not provider: if not provider:
raise HTTPException(status_code=404, detail="Provider not found") raise HTTPException(status_code=404, detail="Provider not found")
since = datetime.now() - timedelta(hours=self.hours) since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
stats = ( stats = (
db.query(ProviderUsageTracking) db.query(ProviderUsageTracking)
.filter( .filter(

View File

@@ -21,7 +21,7 @@ WARNING: 多进程环境注意事项
import asyncio import asyncio
import time import time
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime, timezone
from typing import Any, Deque, Dict from typing import Any, Deque, Dict
from src.core.logger import logger from src.core.logger import logger
@@ -95,12 +95,12 @@ class SlidingWindow:
"""获取最早的重置时间""" """获取最早的重置时间"""
self._cleanup() self._cleanup()
if not self.requests: if not self.requests:
return datetime.now() return datetime.now(timezone.utc)
# 最早的请求将在window_size秒后过期 # 最早的请求将在window_size秒后过期
oldest_request = self.requests[0] oldest_request = self.requests[0]
reset_time = oldest_request + self.window_size reset_time = oldest_request + self.window_size
return datetime.fromtimestamp(reset_time) return datetime.fromtimestamp(reset_time, tz=timezone.utc)
class SlidingWindowStrategy(RateLimitStrategy): class SlidingWindowStrategy(RateLimitStrategy):
@@ -250,7 +250,7 @@ class SlidingWindowStrategy(RateLimitStrategy):
retry_after = None retry_after = None
if not allowed: if not allowed:
# 计算需要等待的时间(最早请求过期的时间) # 计算需要等待的时间(最早请求过期的时间)
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1 retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
return RateLimitResult( return RateLimitResult(
allowed=allowed, allowed=allowed,

View File

@@ -3,7 +3,7 @@
import asyncio import asyncio
import os import os
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync from ...clients.redis_client import get_redis_client_sync
@@ -63,11 +63,11 @@ class TokenBucket:
def get_reset_time(self) -> datetime: def get_reset_time(self) -> datetime:
"""获取下次完全恢复的时间""" """获取下次完全恢复的时间"""
if self.tokens >= self.capacity: if self.tokens >= self.capacity:
return datetime.now() return datetime.now(timezone.utc)
tokens_needed = self.capacity - self.tokens tokens_needed = self.capacity - self.tokens
seconds_to_full = tokens_needed / self.refill_rate seconds_to_full = tokens_needed / self.refill_rate
return datetime.now() + timedelta(seconds=seconds_to_full) return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
class TokenBucketStrategy(RateLimitStrategy): class TokenBucketStrategy(RateLimitStrategy):
@@ -370,7 +370,7 @@ class RedisTokenBucketBackend:
if tokens is None or last_refill is None: if tokens is None or last_refill is None:
remaining = capacity remaining = capacity
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate) reset_at = datetime.now(timezone.utc) + timedelta(seconds=capacity / refill_rate)
else: else:
tokens_value = float(tokens) tokens_value = float(tokens)
last_refill_value = float(last_refill) last_refill_value = float(last_refill)
@@ -378,7 +378,7 @@ class RedisTokenBucketBackend:
tokens_value = min(capacity, tokens_value + delta * refill_rate) tokens_value = min(capacity, tokens_value + delta * refill_rate)
remaining = int(tokens_value) remaining = int(tokens_value)
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
reset_at = datetime.now() + timedelta(seconds=reset_after) reset_at = datetime.now(timezone.utc) + timedelta(seconds=reset_after)
allowed = remaining >= amount allowed = remaining >= amount
retry_after = None retry_after = None