Skip to content

Commit 57ef59d

Browse files
committed
Added IsRateLimitingClientSession protocol. Removed default values for RateLimitingClientSession. Cleanup
1 parent a0a31ab commit 57ef59d

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

src/omnipy/api/protocols/private/util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import TracebackType
12
from typing import Callable, Protocol, runtime_checkable, TypeVar
23

34
from omnipy.api.typedefs import DecoratorClassT
@@ -96,3 +97,24 @@ def take_snapshot_teardown(self) -> None:
9697

9798
def take_snapshot(self, obj: _HasContentsT) -> None:
9899
...
100+
101+
102+
class IsRateLimitingClientSession(Protocol):
103+
""""""
104+
@property
105+
def requests_per_second(self) -> float:
106+
...
107+
108+
async def __aenter__(self) -> 'IsRateLimitingClientSession':
109+
...
110+
111+
async def __aexit__(
112+
self,
113+
exc_type: type[BaseException] | None,
114+
exc_val: BaseException | None,
115+
exc_tb: TracebackType | None,
116+
) -> None:
117+
...
118+
119+
async def close(self) -> None:
120+
...

src/omnipy/modules/remote/helpers.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
import asyncio
22
from datetime import datetime
3-
from typing import Any, cast, Coroutine
3+
from types import TracebackType
4+
from typing import cast
45

56
from aiohttp import ClientSession, TraceConfig
67
from aiolimiter import AsyncLimiter
78

8-
DEFAULT_REQUESTS_PER_TIME_PERIOD = 60
9-
DEFAULT_TIME_PERIOD_IN_SECS = 60
9+
from omnipy.api.protocols.private.util import IsRateLimitingClientSession
1010

1111

1212
class RateLimitingClientSession(ClientSession):
1313
"""
1414
A ClientSession that limits the number of requests made per time period, allowing an initial
1515
burst of requests to go through before rate limiting kicks in for the rest.
1616
"""
17-
def __init__(self,
18-
requests_per_time_period: float = DEFAULT_REQUESTS_PER_TIME_PERIOD,
19-
time_period_in_secs: float = DEFAULT_TIME_PERIOD_IN_SECS,
20-
*args,
17+
def __init__(self, requests_per_time_period: float, time_period_in_secs: float, *args,
2118
**kwargs) -> None:
2219
trace_config = TraceConfig()
2320
trace_config.on_request_start.append(self._limit_request)
@@ -70,5 +67,13 @@ async def _limit_request(self, *args, **kwargs):
7067
def requests_per_second(self) -> float:
7168
return self._requests_per_time_period / self._time_period_in_secs
7269

73-
def __aenter__(self) -> 'Coroutine[Any, Any, RateLimitingClientSession]':
74-
return cast('Coroutine[Any, Any, RateLimitingClientSession]', super().__aenter__())
70+
async def __aenter__(self) -> IsRateLimitingClientSession: # type: ignore[override]
71+
return cast(IsRateLimitingClientSession, await super().__aenter__())
72+
73+
async def __aexit__(
74+
self,
75+
exc_type: type[BaseException] | None,
76+
exc_val: BaseException | None,
77+
exc_tb: TracebackType | None,
78+
) -> None:
79+
await super().__aexit__(exc_type, exc_val, exc_tb)
File renamed without changes.

0 commit comments

Comments
 (0)