Python令牌桶限流

Python令牌桶限流,第1张

关于令牌桶

import functools
import time
import asyncio

from aioredis import StrictRedis


class RateLimit:
    token_key = "RateLimit:token"
    lock_key = "RateLimit:lock"
    push_time_key = "RateLimit:push:time"

    def __init__(self, token_per_second: int = 1, wait_token: int = 0):
        """
        :param token_per_second: 每秒生成的token数目
        :param wait_token: 如果没有token等待的时间
        """
        self.token_per_second = token_per_second
        self.wait_token = wait_token
        self.redis = StrictRedis()

    async def set_last_push_time(self, t: int):
        """设置上次push token的时间"""
        await self.redis.set(self.push_time_key, t)

    @property
    async def last_push_time(self):
        """获取上次push token的时间"""
        return int(await self.redis.get(self.push_time_key) or int(time.time() - 10))

    @property
    async def current_token_count(self):
        """当前token数量"""
        return int(await self.redis.llen(self.token_key) or 0)

    async def push_token(self):
        # token的更新需要加锁
        async with self.redis.lock(self.lock_key, blocking_timeout=1):
            now = int(time.time())
            token_count = (now - await self.last_push_time) * self.token_per_second
            if token_count == 0:
                return
            current_token_count = await self.current_token_count
            token_count = min(token_count + current_token_count, self.token_per_second)
            tokens = [1 for _ in range(token_count - current_token_count)]
            if not tokens:
                return
            await self.redis.rpush(self.token_key, *tokens)
            await self.set_last_push_time(now)

    async def do_get_token(self):
        # 获取令牌
        if await self.redis.lpop(self.token_key):
            return
        raise Exception("limited")

    async def try_get_token(self):
        # 应对突发流量,当前线程可等待其他线程放入令牌以供使用
        # 比如当前没有令牌,等待的过程中,下个线程塞入令牌时当前线程获取之后可正常调用
        now = time.time()
        while time.time() - self.wait_token < now:
            if await self.current_token_count > 0:
                return await self.do_get_token()
            await asyncio.sleep(0.1)
        # block之后重新塞入token,避免"流量全在0-1秒,1秒过后没有请求则令牌不更新"情况出现
        self.push_token()
        if await self.current_token_count <= 0:
            raise Exception("limited")

    async def can_do(self):
        if self.wait_token == 0:
            return await self.do_get_token()
        await self.try_get_token()

    async def do(self):
        await self.push_token()
        await self.can_do()

    async def __aenter__(self):
        await self.do()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        ...

    def __call__(self, func: callable):
        if not asyncio.iscoroutinefunction(func):
            raise Exception("only support coroutine function")
        @functools.wraps(func)
        async def inner(*args, **kwargs):
            await self.do()
            return await func(*args, **kwargs)
        return inner


if __name__ == '__main__':
    limit = RateLimit(3, 5)


    @limit
    async def test_wrapper(u, i):
        print(u, i)


    async def test_with(u, i):
        async with RateLimit(3, 5):
            print(u, i)


    async def main(u):
        for i in range(3):
            await test_wrapper(u, i)
            # await test_with(u, i)
            # 6次main函数执行(tasks中有6个),令牌只有三个,sleep之后下一次执行会塞入新的令牌
            # 之前被阻塞的方法会消费sleep之后函数塞入的令牌,实现预消费(防范突发状况)
            await asyncio.sleep(2)

    loop = asyncio.get_event_loop()
    tasks = [main(1), main(2), main(3), main(4), main(5), main(6)]
    loop.run_until_complete(asyncio.wait(tasks))


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/567929.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存