关于令牌桶
import functools
import time
import asyncio
from aioredis import StrictRedis
class RateLimit:
token_key = "RateLimit:token"
lock_key = "RateLimit:lock"
push_time_key = "RateLimit:push:time"
def __init__(self, token_per_second: int = 1, wait_token: int = 0):
"""
:param token_per_second: 每秒生成的token数目
:param wait_token: 如果没有token等待的时间
"""
self.token_per_second = token_per_second
self.wait_token = wait_token
self.redis = StrictRedis()
async def set_last_push_time(self, t: int):
"""设置上次push token的时间"""
await self.redis.set(self.push_time_key, t)
@property
async def last_push_time(self):
"""获取上次push token的时间"""
return int(await self.redis.get(self.push_time_key) or int(time.time() - 10))
@property
async def current_token_count(self):
"""当前token数量"""
return int(await self.redis.llen(self.token_key) or 0)
async def push_token(self):
# token的更新需要加锁
async with self.redis.lock(self.lock_key, blocking_timeout=1):
now = int(time.time())
token_count = (now - await self.last_push_time) * self.token_per_second
if token_count == 0:
return
current_token_count = await self.current_token_count
token_count = min(token_count + current_token_count, self.token_per_second)
tokens = [1 for _ in range(token_count - current_token_count)]
if not tokens:
return
await self.redis.rpush(self.token_key, *tokens)
await self.set_last_push_time(now)
async def do_get_token(self):
# 获取令牌
if await self.redis.lpop(self.token_key):
return
raise Exception("limited")
async def try_get_token(self):
# 应对突发流量,当前线程可等待其他线程放入令牌以供使用
# 比如当前没有令牌,等待的过程中,下个线程塞入令牌时当前线程获取之后可正常调用
now = time.time()
while time.time() - self.wait_token < now:
if await self.current_token_count > 0:
return await self.do_get_token()
await asyncio.sleep(0.1)
# block之后重新塞入token,避免"流量全在0-1秒,1秒过后没有请求则令牌不更新"情况出现
self.push_token()
if await self.current_token_count <= 0:
raise Exception("limited")
async def can_do(self):
if self.wait_token == 0:
return await self.do_get_token()
await self.try_get_token()
async def do(self):
await self.push_token()
await self.can_do()
async def __aenter__(self):
await self.do()
async def __aexit__(self, exc_type, exc_val, exc_tb):
...
def __call__(self, func: callable):
if not asyncio.iscoroutinefunction(func):
raise Exception("only support coroutine function")
@functools.wraps(func)
async def inner(*args, **kwargs):
await self.do()
return await func(*args, **kwargs)
return inner
if __name__ == '__main__':
limit = RateLimit(3, 5)
@limit
async def test_wrapper(u, i):
print(u, i)
async def test_with(u, i):
async with RateLimit(3, 5):
print(u, i)
async def main(u):
for i in range(3):
await test_wrapper(u, i)
# await test_with(u, i)
# 6次main函数执行(tasks中有6个),令牌只有三个,sleep之后下一次执行会塞入新的令牌
# 之前被阻塞的方法会消费sleep之后函数塞入的令牌,实现预消费(防范突发状况)
await asyncio.sleep(2)
loop = asyncio.get_event_loop()
tasks = [main(1), main(2), main(3), main(4), main(5), main(6)]
loop.run_until_complete(asyncio.wait(tasks))
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)