diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index a69606af..0a7abd18 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -45,6 +45,11 @@ def connection_lost(self, exc: Exception | None) -> None: self.connection_lost_cb(exc) +def get_running_loop() -> asyncio.AbstractEventLoop: + """Get the running event loop, extracted for mocking purposes.""" + return asyncio.get_running_loop() + + class LocalChannel(Channel): """Simple RPC-style channel for communicating with a device over a local network. @@ -179,7 +184,7 @@ async def connect(self) -> None: if self._is_connected: self._logger.debug("Unexpected call to connect when already connected") return - loop = asyncio.get_running_loop() + loop = get_running_loop() protocol = _LocalProtocol(self._data_received, self._connection_lost) try: self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT) diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index fcfc4a0c..1cc6f842 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -52,7 +52,7 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]: loop = Mock() loop.create_connection = AsyncMock(return_value=(mock_transport, Mock())) - with patch("asyncio.get_running_loop", return_value=loop): + with patch("roborock.devices.local_channel.get_running_loop", return_value=loop): yield loop diff --git a/tests/e2e/__snapshots__/test_mqtt_session.ambr b/tests/e2e/__snapshots__/test_mqtt_session.ambr index aa23b772..533a573d 100644 --- a/tests/e2e/__snapshots__/test_mqtt_session.ambr +++ b/tests/e2e/__snapshots__/test_mqtt_session.ambr @@ -1,11 +1,11 @@ # serializer version: 1 # name: test_session_e2e_publish_message - [mqtt <] - 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| [mqtt >] 00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....| 00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw| 00000020 6f 72 64 |ord| + [mqtt <] + 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| [mqtt >] 00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.| 00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .| @@ -14,13 +14,13 @@ 00000040 99 71 bf |.q.| # --- # name: test_session_e2e_receive_message - [mqtt <] - 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| [mqtt >] 00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....| 00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw| 00000020 6f 72 64 |ord| [mqtt <] + 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| + [mqtt <] 00000000 90 04 00 01 00 00 |......| [mqtt >] 00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.| diff --git a/tests/fixtures/aiomqtt_fixtures.py b/tests/fixtures/aiomqtt_fixtures.py index d9e10e74..ac508f21 100644 --- a/tests/fixtures/aiomqtt_fixtures.py +++ b/tests/fixtures/aiomqtt_fixtures.py @@ -28,10 +28,13 @@ async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]: async def poll_sockets(client: mqtt.Client) -> None: """Poll the mqtt client sockets in a loop to pick up new data.""" - while True: - event_loop.call_soon_threadsafe(client.loop_read) - event_loop.call_soon_threadsafe(client.loop_write) - await asyncio.sleep(0.01) + try: + while True: + event_loop.call_soon_threadsafe(client.loop_read) + event_loop.call_soon_threadsafe(client.loop_write) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + pass task: asyncio.Task[None] | None = None @@ -52,6 +55,7 @@ def new_client(*args: Any, **kwargs: Any) -> mqtt.Client: yield if task: task.cancel() + await task @pytest.fixture diff --git a/tests/fixtures/local_async_fixtures.py b/tests/fixtures/local_async_fixtures.py index d804df82..e328d663 100644 --- a/tests/fixtures/local_async_fixtures.py +++ b/tests/fixtures/local_async_fixtures.py @@ -79,7 +79,7 @@ def start_handle_write(data: bytes) -> None: return (mock_transport, protocol) - with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: + with patch("roborock.devices.local_channel.get_running_loop") as mock_loop: mock_loop.return_value.create_connection.side_effect = create_connection yield diff --git a/tests/fixtures/mqtt.py b/tests/fixtures/mqtt.py index 08489de3..a765da5b 100644 --- a/tests/fixtures/mqtt.py +++ b/tests/fixtures/mqtt.py @@ -32,6 +32,7 @@ def __init__( self.handle_request = handle_request self.response_queue = response_queue self.log = log + self.client_connected = False def pending(self) -> int: """Return the number of bytes in the response buffer.""" @@ -52,6 +53,7 @@ def handle_socket_recv(self, read_size: int) -> bytes: def handle_socket_send(self, client_request: bytes) -> int: """Receive an incoming request from the client.""" + self.client_connected = True _LOGGER.debug("Request: 0x%s", client_request.hex()) self.log.add_log_entry("[mqtt >]", client_request) if (response := self.handle_request(client_request)) is not None: @@ -64,7 +66,7 @@ def handle_socket_send(self, client_request: bytes) -> int: def push_response(self) -> None: """Push a response to the client.""" - if not self.response_queue.empty(): + if not self.response_queue.empty() and self.client_connected: response = self.response_queue.get() # Enqueue a response to be sent back to the client in the buffer. # The buffer will be emptied when the client calls recv() on the socket diff --git a/tests/fixtures/pahomqtt_fixtures.py b/tests/fixtures/pahomqtt_fixtures.py index 97655f3d..ecdfe69b 100644 --- a/tests/fixtures/pahomqtt_fixtures.py +++ b/tests/fixtures/pahomqtt_fixtures.py @@ -1,6 +1,7 @@ """Common code for MQTT tests.""" import logging +import warnings from collections.abc import Callable, Generator from queue import Queue from typing import Any @@ -50,9 +51,12 @@ def handle_select(rlist: list, wlist: list, *args: Any) -> list: @pytest.fixture(name="fake_mqtt_socket_handler") def fake_mqtt_socket_handler_fixture( mqtt_request_handler: MqttRequestHandler, mqtt_response_queue: Queue[bytes], log: CapturedRequestLog -) -> FakeMqttSocketHandler: +) -> Generator[FakeMqttSocketHandler, None, None]: """Fixture that creates a fake MQTT broker.""" - return FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log) + socket_handler = FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log) + yield socket_handler + if len(socket_handler.response_buf.getvalue()) > 0: + warnings.warn("Some enqueued MQTT responses were not consumed during the test") @pytest.fixture(name="mock_sock") @@ -76,7 +80,8 @@ def response_queue_fixture() -> Generator[Queue[bytes], None, None]: """Fixture that provides a queue for enqueueing responses to be sent to the client under test.""" response_queue: Queue[bytes] = Queue() yield response_queue - assert response_queue.empty(), "Not all fake responses were consumed" + if not response_queue.empty(): + warnings.warn("Some enqueued MQTT responses were not consumed during the test") @pytest.fixture(name="mqtt_request_handler") diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 15526b66..17ca9c78 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -151,8 +151,6 @@ async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None """Test the MQTT session.""" push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) - push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) - push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected @@ -528,8 +526,6 @@ def succeed_then_fail_unauthorized() -> Any: # Don't produce messages, just exit and restart to reconnect message_iterator.loop = False - push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) - session = await create_mqtt_session(params) assert session.connected