Skip to content

Commit a15e1f4

Browse files
authored
Introduce max_concurrent_consumed_messages: stop reading when threshold is reached on slow consumer callbacks (#135)
1 parent 7a9092a commit a15e1f4

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ async with stompman.Client(
4040
disconnect_confirmation_timeout=2,
4141
write_retry_attempts=3,
4242
check_server_alive_interval_factor=3,
43+
max_concurrent_consumed_messages=10,
4344
) as client:
4445
...
4546
```

packages/stompman/stompman/client.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ class Client:
4242
disconnect_confirmation_timeout: int = 2
4343
check_server_alive_interval_factor: int = 3
4444
"""Client will check if server alive `server heartbeat interval` times `interval factor`"""
45+
max_concurrent_consumed_messages: int = 10
4546

4647
connection_class: type[AbstractConnection] = Connection
4748

4849
_connection_manager: ConnectionManager = field(init=False)
49-
_active_subscriptions: ActiveSubscriptions = field(default_factory=ActiveSubscriptions, init=False)
50+
_active_subscriptions: ActiveSubscriptions = field(default_factory=ActiveSubscriptions, init=False, repr=False)
5051
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
5152
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
52-
_listen_task: asyncio.Task[None] = field(init=False)
53-
_task_group: asyncio.TaskGroup = field(init=False)
53+
_listen_task: asyncio.Task[None] = field(init=False, repr=False)
54+
_task_group: asyncio.TaskGroup = field(init=False, repr=False)
55+
_message_frame_semaphore: asyncio.Semaphore = field(init=False, repr=False)
5456

5557
def __post_init__(self) -> None:
5658
self._connection_manager = ConnectionManager(
@@ -73,6 +75,7 @@ def __post_init__(self) -> None:
7375
check_server_alive_interval_factor=self.check_server_alive_interval_factor,
7476
ssl=self.ssl,
7577
)
78+
self._message_frame_semaphore = asyncio.Semaphore(self.max_concurrent_consumed_messages)
7679

7780
async def __aenter__(self) -> Self:
7881
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
@@ -96,16 +99,17 @@ async def _listen_to_frames(self) -> None:
9699
async for frame in self._connection_manager.read_frames_reconnecting():
97100
match frame:
98101
case MessageFrame():
99-
if subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"]):
100-
task_group.create_task(
101-
subscription._run_handler(frame=frame) # noqa: SLF001
102-
if isinstance(subscription, AutoAckSubscription)
103-
else subscription.handler(
104-
AckableMessageFrame(
105-
headers=frame.headers, body=frame.body, _subscription=subscription
106-
)
107-
)
102+
if not (subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"])):
103+
continue
104+
await self._message_frame_semaphore.acquire()
105+
created_task = task_group.create_task(
106+
subscription._run_handler(frame=frame) # noqa: SLF001
107+
if isinstance(subscription, AutoAckSubscription)
108+
else subscription.handler(
109+
AckableMessageFrame(headers=frame.headers, body=frame.body, _subscription=subscription)
108110
)
111+
)
112+
created_task.add_done_callback(lambda _: self._message_frame_semaphore.release())
109113
case ErrorFrame():
110114
if self.on_error_frame:
111115
self.on_error_frame(frame)

0 commit comments

Comments
 (0)