Skip to content

Commit 35cccce

Browse files
committed
antiflood
1 parent 83cecf4 commit 35cccce

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ alt="vtffijuUg5Y" width="240" height="180" border="10" /></a>
5252
- [ ] Analytics
5353
- [ ] Rate limit
5454
- [ ] Inline buttons
55+
- [ ] Improved structure
56+
- [ ] Middlewares
5557

5658
## Changelog
5759
### 0.1.0

bot/__main__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from aiogram import Bot, types
66
from aiogram.dispatcher import Dispatcher
77

8+
from aiogram.contrib.fsm_storage.memory import MemoryStorage
9+
810
import asyncio
911
from aiogram.utils import executor
1012
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -19,11 +21,24 @@
1921
from bot.core.scheduler import scheduler_job
2022

2123
from bot.handlers import setup as handlers_setup
24+
from bot.middlewares.antiflood import ThrottlingMiddleware, rate_limit
2225

2326

2427
scheduler = AsyncIOScheduler()
2528

26-
dp = Dispatcher(bot, loop=asyncio.get_event_loop())
29+
# storage = RedisStorage2(db=2)
30+
31+
# storage = RedisStorage2(
32+
# host="redis:/localhost",
33+
# port=6379,
34+
# db=7,
35+
# )
36+
37+
dp = Dispatcher(
38+
bot,
39+
loop=loop,
40+
storage=MemoryStorage(),
41+
)
2742

2843

2944
async def startup(dp: Dispatcher):
@@ -63,6 +78,7 @@ async def startup(dp: Dispatcher):
6378

6479

6580
@dp.message_handler(commands=["ping"])
81+
@rate_limit(5, "ping")
6682
async def ping(message: types.Message):
6783
await message.answer("Pong!")
6884

@@ -73,6 +89,7 @@ async def time(message: types.Message):
7389

7490

7591
def main():
92+
dp.middleware.setup(ThrottlingMiddleware())
7693
scheduler.add_job(
7794
scheduler_job,
7895
"interval",

bot/middlewares/antiflood.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import asyncio
2+
3+
from aiogram import Dispatcher, types
4+
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
5+
from aiogram.dispatcher.handler import CancelHandler, current_handler
6+
from aiogram.dispatcher.middlewares import BaseMiddleware
7+
from aiogram.utils.exceptions import Throttled
8+
9+
10+
def rate_limit(limit: int, key=None):
11+
"""
12+
13+
Decorator for configuring rate limit and key in different functions.
14+
15+
16+
:param limit:
17+
18+
:param key:
19+
20+
:return:
21+
22+
"""
23+
24+
def decorator(func):
25+
setattr(func, "throttling_rate_limit", limit)
26+
27+
if key:
28+
setattr(func, "throttling_key", key)
29+
30+
return func
31+
32+
return decorator
33+
34+
35+
class ThrottlingMiddleware(BaseMiddleware):
36+
37+
"""
38+
39+
Simple middleware
40+
41+
"""
42+
43+
def __init__(self, limit=3, key_prefix="antiflood_"):
44+
self.rate_limit = limit
45+
46+
self.prefix = key_prefix
47+
48+
super(ThrottlingMiddleware, self).__init__()
49+
50+
async def on_process_message(self, message: types.Message, data: dict):
51+
"""
52+
53+
This handler is called when dispatcher receives a message
54+
55+
56+
:param message:
57+
58+
"""
59+
60+
# Get current handler
61+
62+
handler = current_handler.get()
63+
64+
# Get dispatcher from context
65+
66+
dispatcher = Dispatcher.get_current()
67+
68+
# If handler was configured, get rate limit and key from handler
69+
70+
if handler:
71+
limit = getattr(handler, "throttling_rate_limit", self.rate_limit)
72+
73+
key = getattr(
74+
handler, "throttling_key", f"{self.prefix}_{handler.__name__}"
75+
)
76+
77+
else:
78+
limit = self.rate_limit
79+
80+
key = f"{self.prefix}_message"
81+
82+
# Use Dispatcher.throttle method.
83+
84+
try:
85+
await dispatcher.throttle(key, rate=limit)
86+
87+
except Throttled as t:
88+
# Execute action
89+
90+
await self.message_throttled(message, t)
91+
92+
# Cancel current handler
93+
94+
raise CancelHandler()
95+
96+
async def message_throttled(self, message: types.Message, throttled: Throttled):
97+
"""
98+
99+
Notify user only on first exceed and notify about unlocking only on last exceed
100+
101+
102+
:param message:
103+
104+
:param throttled:
105+
106+
"""
107+
108+
handler = current_handler.get()
109+
110+
dispatcher = Dispatcher.get_current()
111+
112+
if handler:
113+
key = getattr(
114+
handler, "throttling_key", f"{self.prefix}_{handler.__name__}"
115+
)
116+
117+
else:
118+
key = f"{self.prefix}_message"
119+
120+
# Calculate how many time is left till the block ends
121+
122+
delta = throttled.rate - throttled.delta
123+
124+
# Prevent flooding
125+
126+
if throttled.exceeded_count <= 2:
127+
await message.reply("Too many requests! ")
128+
129+
# Sleep.
130+
131+
await asyncio.sleep(delta)
132+
133+
# Check lock status
134+
135+
thr = await dispatcher.check_key(key)
136+
137+
# If current message is not last with current key - do not send message
138+
139+
if thr.exceeded_count == throttled.exceeded_count:
140+
await message.reply("Unlocked.")

0 commit comments

Comments
 (0)