diff --git a/stompman/subscription.py b/stompman/subscription.py index 4d826cdf..86ae5150 100644 --- a/stompman/subscription.py +++ b/stompman/subscription.py @@ -50,18 +50,24 @@ async def _run_handler(self, *, frame: MessageFrame) -> None: try: await self.handler(frame) except self.suppressed_exception_classes as exception: - if self._should_handle_ack_nack and self.id in self._active_subscriptions: + if ( + self._should_handle_ack_nack + and self.id in self._active_subscriptions + and (ack_id := frame.headers["ack"]) + ): await self._connection_manager.maybe_write_frame( - NackFrame( - headers={"id": frame.headers["message-id"], "subscription": frame.headers["subscription"]} - ) + NackFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]}) ) self.on_suppressed_exception(exception, frame) else: - if self._should_handle_ack_nack and self.id in self._active_subscriptions: + if ( + self._should_handle_ack_nack + and self.id in self._active_subscriptions + and (ack_id := frame.headers["ack"]) + ): await self._connection_manager.maybe_write_frame( AckFrame( - headers={"id": frame.headers["message-id"], "subscription": frame.headers["subscription"]}, + headers={"id": ack_id, "subscription": frame.headers["subscription"]}, ) ) diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 0867557a..87ab1e64 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -223,11 +223,11 @@ async def test_client_listen_unsubscribe_before_ack_or_nack( async def test_client_listen_ack_nack_sent( monkeypatch: pytest.MonkeyPatch, faker: faker.Faker, ack: AckMode, *, ok: bool ) -> None: - subscription_id, destination, message_id = faker.pystr(), faker.pystr(), faker.pystr() + subscription_id, destination, ack_id = faker.pystr(), faker.pystr(), faker.pystr() monkeypatch.setattr(stompman.subscription, "_make_subscription_id", mock.Mock(return_value=subscription_id)) message_frame = build_dataclass( - MessageFrame, headers={"destination": destination, "message-id": message_id, "subscription": subscription_id} + MessageFrame, headers={"destination": destination, "ack": ack_id, "subscription": subscription_id} ) connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([message_frame])) message_handler = mock.AsyncMock(side_effect=None if ok else SomeError) @@ -244,9 +244,9 @@ async def test_client_listen_ack_nack_sent( assert collected_frames == enrich_expected_frames( SubscribeFrame(headers={"id": subscription_id, "destination": destination, "ack": ack}), message_frame, - AckFrame(headers={"id": message_id, "subscription": subscription_id}) + AckFrame(headers={"id": ack_id, "subscription": subscription_id}) if ok - else NackFrame(headers={"id": message_id, "subscription": subscription_id}), + else NackFrame(headers={"id": ack_id, "subscription": subscription_id}), UnsubscribeFrame(headers={"id": subscription_id}), )