Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def connection_lost(self, exc: Exception | None) -> None:
self.connection_lost_cb(exc)


def get_running_loop() -> asyncio.AbstractEventLoop:
"""Get the running event loop, extracted for mocking purposes."""
return asyncio.get_running_loop()


class LocalChannel(Channel):
"""Simple RPC-style channel for communicating with a device over a local network.

Expand Down Expand Up @@ -179,7 +184,7 @@ async def connect(self) -> None:
if self._is_connected:
self._logger.debug("Unexpected call to connect when already connected")
return
loop = asyncio.get_running_loop()
loop = get_running_loop()
protocol = _LocalProtocol(self._data_received, self._connection_lost)
try:
self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/test_local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]:
loop = Mock()
loop.create_connection = AsyncMock(return_value=(mock_transport, Mock()))

with patch("asyncio.get_running_loop", return_value=loop):
with patch("roborock.devices.local_channel.get_running_loop", return_value=loop):
yield loop


Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/__snapshots__/test_mqtt_session.ambr
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# serializer version: 1
# name: test_session_e2e_publish_message
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt >]
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
00000020 6f 72 64 |ord|
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt >]
00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.|
00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .|
Expand All @@ -14,13 +14,13 @@
00000040 99 71 bf |.q.|
# ---
# name: test_session_e2e_receive_message
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt >]
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
00000020 6f 72 64 |ord|
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt <]
00000000 90 04 00 01 00 00 |......|
[mqtt >]
00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.|
Expand Down
12 changes: 8 additions & 4 deletions tests/fixtures/aiomqtt_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]:

async def poll_sockets(client: mqtt.Client) -> None:
"""Poll the mqtt client sockets in a loop to pick up new data."""
while True:
event_loop.call_soon_threadsafe(client.loop_read)
event_loop.call_soon_threadsafe(client.loop_write)
await asyncio.sleep(0.01)
try:
while True:
event_loop.call_soon_threadsafe(client.loop_read)
event_loop.call_soon_threadsafe(client.loop_write)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
pass

task: asyncio.Task[None] | None = None

Expand All @@ -52,6 +55,7 @@ def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
yield
if task:
task.cancel()
await task


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/local_async_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def start_handle_write(data: bytes) -> None:

return (mock_transport, protocol)

with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
with patch("roborock.devices.local_channel.get_running_loop") as mock_loop:
mock_loop.return_value.create_connection.side_effect = create_connection
yield

Expand Down
4 changes: 3 additions & 1 deletion tests/fixtures/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self.handle_request = handle_request
self.response_queue = response_queue
self.log = log
self.client_connected = False

def pending(self) -> int:
"""Return the number of bytes in the response buffer."""
Expand All @@ -52,6 +53,7 @@ def handle_socket_recv(self, read_size: int) -> bytes:

def handle_socket_send(self, client_request: bytes) -> int:
"""Receive an incoming request from the client."""
self.client_connected = True
_LOGGER.debug("Request: 0x%s", client_request.hex())
self.log.add_log_entry("[mqtt >]", client_request)
if (response := self.handle_request(client_request)) is not None:
Expand All @@ -64,7 +66,7 @@ def handle_socket_send(self, client_request: bytes) -> int:

def push_response(self) -> None:
"""Push a response to the client."""
if not self.response_queue.empty():
if not self.response_queue.empty() and self.client_connected:
response = self.response_queue.get()
# Enqueue a response to be sent back to the client in the buffer.
# The buffer will be emptied when the client calls recv() on the socket
Expand Down
11 changes: 8 additions & 3 deletions tests/fixtures/pahomqtt_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Common code for MQTT tests."""

import logging
import warnings
from collections.abc import Callable, Generator
from queue import Queue
from typing import Any
Expand Down Expand Up @@ -50,9 +51,12 @@ def handle_select(rlist: list, wlist: list, *args: Any) -> list:
@pytest.fixture(name="fake_mqtt_socket_handler")
def fake_mqtt_socket_handler_fixture(
mqtt_request_handler: MqttRequestHandler, mqtt_response_queue: Queue[bytes], log: CapturedRequestLog
) -> FakeMqttSocketHandler:
) -> Generator[FakeMqttSocketHandler, None, None]:
"""Fixture that creates a fake MQTT broker."""
return FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log)
socket_handler = FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log)
yield socket_handler
if len(socket_handler.response_buf.getvalue()) > 0:
warnings.warn("Some enqueued MQTT responses were not consumed during the test")


@pytest.fixture(name="mock_sock")
Expand All @@ -76,7 +80,8 @@ def response_queue_fixture() -> Generator[Queue[bytes], None, None]:
"""Fixture that provides a queue for enqueueing responses to be sent to the client under test."""
response_queue: Queue[bytes] = Queue()
yield response_queue
assert response_queue.empty(), "Not all fake responses were consumed"
if not response_queue.empty():
warnings.warn("Some enqueued MQTT responses were not consumed during the test")


@pytest.fixture(name="mqtt_request_handler")
Expand Down
4 changes: 0 additions & 4 deletions tests/mqtt/test_roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None
"""Test the MQTT session."""

push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected

Expand Down Expand Up @@ -528,8 +526,6 @@ def succeed_then_fail_unauthorized() -> Any:
# Don't produce messages, just exit and restart to reconnect
message_iterator.loop = False

push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))

session = await create_mqtt_session(params)
assert session.connected

Expand Down