From 823ee75d0d480e17112a3cb5b9e0e5841028a310 Mon Sep 17 00:00:00 2001 From: Lev Vereshchagin Date: Sat, 2 Nov 2024 11:49:20 +0300 Subject: [PATCH] Allow async `on_heartbeat` callback --- README.md | 2 +- stompman/client.py | 12 +++++-- tests/test_connection_lifespan.py | 54 +++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 60c62897..30cf7888 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ async with stompman.Client( # Handlers: on_error_frame=lambda error_frame: print(error_frame.body), - on_heartbeat=lambda: print("Server sent a heartbeat"), + on_heartbeat=lambda: print("Server sent a heartbeat"), # also can be async # SSL — can be either `None` (default), `True`, or `ssl.SSLContext' ssl=None, diff --git a/stompman/client.py b/stompman/client.py index f9f2f904..0e238332 100644 --- a/stompman/client.py +++ b/stompman/client.py @@ -1,9 +1,9 @@ import asyncio +import inspect from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass, field from functools import partial -import inspect from ssl import SSLContext from types import TracebackType from typing import ClassVar, Literal, Self @@ -31,7 +31,7 @@ class Client: servers: list[ConnectionParameters] = field(kw_only=False) on_error_frame: Callable[[ErrorFrame], None] | None = None - on_heartbeat: Callable[[], None] | None = None + on_heartbeat: Callable[[], None] | Callable[[], Awaitable[None]] | None = None heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000)) ssl: Literal[True] | SSLContext | None = None @@ -53,6 +53,7 @@ class Client: _heartbeat_task: asyncio.Task[None] = field(init=False) _listen_task: asyncio.Task[None] = field(init=False) _task_group: asyncio.TaskGroup = field(init=False) + _on_heartbeat_is_async: bool = field(init=False) def __post_init__(self) -> None: self._connection_manager = ConnectionManager( @@ -76,6 +77,7 @@ def __post_init__(self) -> None: write_retry_attempts=self.write_retry_attempts, ssl=self.ssl, ) + self._on_heartbeat_is_async = inspect.iscoroutinefunction(self.on_heartbeat) if self.on_heartbeat else False async def __aenter__(self) -> Self: self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup()) @@ -116,7 +118,11 @@ async def _listen_to_frames(self) -> None: if self.on_error_frame: self.on_error_frame(frame) case HeartbeatFrame(): - if self.on_heartbeat: + if self.on_heartbeat is None: + pass + elif self._on_heartbeat_is_async: + task_group.create_task(self.on_heartbeat()) # type: ignore[arg-type] + else: self.on_heartbeat() case ConnectedFrame() | ReceiptFrame(): pass diff --git a/tests/test_connection_lifespan.py b/tests/test_connection_lifespan.py index 4a8d1abe..ed79862f 100644 --- a/tests/test_connection_lifespan.py +++ b/tests/test_connection_lifespan.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import AsyncGenerator, Coroutine +from functools import partial from typing import Any from unittest import mock @@ -17,6 +18,7 @@ DisconnectFrame, ErrorFrame, FailedAllConnectAttemptsError, + HeartbeatFrame, ReceiptFrame, UnsupportedProtocolVersion, ) @@ -154,6 +156,58 @@ async def mock_sleep(delay: float) -> None: assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()] +async def test_client_on_heartbeat_none(monkeypatch: pytest.MonkeyPatch) -> None: + real_sleep = asyncio.sleep + monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0)) + connection_class, _ = create_spying_connection( + *get_read_frames_with_lifespan( + [build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)] + ) + ) + + async with EnrichedClient(connection_class=connection_class, on_heartbeat=None): + await real_sleep(0) + await real_sleep(0) + await real_sleep(0) + + +async def test_client_on_heartbeat_sync(monkeypatch: pytest.MonkeyPatch) -> None: + real_sleep = asyncio.sleep + monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0)) + connection_class, _ = create_spying_connection( + *get_read_frames_with_lifespan( + [build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)] + ) + ) + on_heartbeat_mock = mock.Mock() + + async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock): + await real_sleep(0) + await real_sleep(0) + await real_sleep(0) + + assert on_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()] + + +async def test_client_on_heartbeat_async(monkeypatch: pytest.MonkeyPatch) -> None: + real_sleep = asyncio.sleep + monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0)) + connection_class, _ = create_spying_connection( + *get_read_frames_with_lifespan( + [build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)] + ) + ) + on_heartbeat_mock = mock.AsyncMock() + + async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock): + await real_sleep(0) + await real_sleep(0) + await real_sleep(0) + + assert on_heartbeat_mock.await_count == 3 # noqa: PLR2004 + assert on_heartbeat_mock.mock_calls == [mock.call.__bool__(), mock.call(), mock.call(), mock.call()] + + def test_make_receipt_id(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.undo() stompman.connection_lifespan._make_receipt_id()