|
1 | 1 | import asyncio |
2 | 2 | from datetime import datetime |
3 | | -from typing import Any, cast, Coroutine |
| 3 | +from types import TracebackType |
| 4 | +from typing import cast |
4 | 5 |
|
5 | 6 | from aiohttp import ClientSession, TraceConfig |
6 | 7 | from aiolimiter import AsyncLimiter |
7 | 8 |
|
8 | | -DEFAULT_REQUESTS_PER_TIME_PERIOD = 60 |
9 | | -DEFAULT_TIME_PERIOD_IN_SECS = 60 |
| 9 | +from omnipy.api.protocols.private.util import IsRateLimitingClientSession |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class RateLimitingClientSession(ClientSession): |
13 | 13 | """ |
14 | 14 | A ClientSession that limits the number of requests made per time period, allowing an initial |
15 | 15 | burst of requests to go through before rate limiting kicks in for the rest. |
16 | 16 | """ |
17 | | - def __init__(self, |
18 | | - requests_per_time_period: float = DEFAULT_REQUESTS_PER_TIME_PERIOD, |
19 | | - time_period_in_secs: float = DEFAULT_TIME_PERIOD_IN_SECS, |
20 | | - *args, |
| 17 | + def __init__(self, requests_per_time_period: float, time_period_in_secs: float, *args, |
21 | 18 | **kwargs) -> None: |
22 | 19 | trace_config = TraceConfig() |
23 | 20 | trace_config.on_request_start.append(self._limit_request) |
@@ -70,5 +67,13 @@ async def _limit_request(self, *args, **kwargs): |
70 | 67 | def requests_per_second(self) -> float: |
71 | 68 | return self._requests_per_time_period / self._time_period_in_secs |
72 | 69 |
|
73 | | - def __aenter__(self) -> 'Coroutine[Any, Any, RateLimitingClientSession]': |
74 | | - return cast('Coroutine[Any, Any, RateLimitingClientSession]', super().__aenter__()) |
| 70 | + async def __aenter__(self) -> IsRateLimitingClientSession: # type: ignore[override] |
| 71 | + return cast(IsRateLimitingClientSession, await super().__aenter__()) |
| 72 | + |
| 73 | + async def __aexit__( |
| 74 | + self, |
| 75 | + exc_type: type[BaseException] | None, |
| 76 | + exc_val: BaseException | None, |
| 77 | + exc_tb: TracebackType | None, |
| 78 | + ) -> None: |
| 79 | + await super().__aexit__(exc_type, exc_val, exc_tb) |
0 commit comments