Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ def _l01_iv(timestamp: int, nonce: int, sequence: int) -> bytes:
return digest[:12]

@staticmethod
def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int) -> bytes:
def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int | None = None) -> bytes:
"""Derive AAD for L01 protocol."""
return (
sequence.to_bytes(4, "big")
+ connect_nonce.to_bytes(4, "big")
+ ack_nonce.to_bytes(4, "big")
+ (ack_nonce.to_bytes(4, "big") if ack_nonce is not None else b"")
+ nonce.to_bytes(4, "big")
+ timestamp.to_bytes(4, "big")
)
Expand All @@ -181,7 +181,7 @@ def encrypt_gcm_l01(
sequence: int,
nonce: int,
connect_nonce: int,
ack_nonce: int,
ack_nonce: int | None = None,
) -> bytes:
"""Encrypt plaintext for L01 protocol using AES-256-GCM."""
if not isinstance(plaintext, bytes):
Expand Down
54 changes: 54 additions & 0 deletions tests/e2e/__snapshots__/test_local_session.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# serializer version: 1
# name: test_connect
[local >]
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
[local <]
00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h|
00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH|
00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;|
# ---
# name: test_l01_session
[local >]
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
[local <]
00000000 00 |.|
[local >]
00000000 00 00 00 15 4c 30 31 00 00 00 01 00 00 23 82 68 |....L01......#.h|
00000010 a6 a2 25 00 00 ee 2f 30 e8 |..%.../0.|
[local <]
00000000 00 00 00 29 4c 30 31 00 00 00 01 00 00 00 17 68 |...)L01........h|
00000010 a6 a2 23 00 01 00 12 a0 4a ec 75 88 03 75 0f d2 |..#.....J.u..u..|
00000020 40 33 69 02 f4 71 50 72 f3 81 56 80 f4 |@3i..qPr..V..|
[local >]
00000000 00 00 00 3e 4c 30 31 00 00 00 7b 00 00 23 83 68 |...>L01...{..#.h|
00000010 a6 a2 26 00 65 00 27 9e fd c2 42 b7 01 b4 eb 9c |..&.e.'...B.....|
00000020 00 84 4f fd 51 1f bc a5 65 12 c2 dc 45 0e 21 cb |..O.Q...e...E.!.|
00000030 45 dc bb 0a ba 16 84 28 a7 33 e5 e2 fa a8 f1 f2 |E......(.3......|
00000040 ec f4 |..|
[local <]
00000000 00 00 00 37 4c 30 31 00 00 00 7b 00 00 00 17 68 |...7L01...{....h|
00000010 a6 a2 27 00 66 00 20 b7 72 49 8a 64 eb 16 a5 71 |..'.f. .rI.d...q|
00000020 73 eb 9e 7e 37 64 3e 75 c0 70 ea 39 4e de 82 1f |s..~7d>u.p.9N...|
00000030 e2 29 86 de 4a 7b 38 20 55 12 8a |.)..J{8 U..|
# ---
# name: test_send_command
[local >]
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
[local <]
00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h|
00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH|
00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;|
[local >]
00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 23 83 68 |...71.0...{..#.h|
00000010 a6 a2 25 00 65 00 20 91 5b 1f 43 34 d5 22 47 9f |..%.e. .[.C4."G.|
00000020 59 4e 45 53 85 f9 c6 6e f2 eb 27 eb 6d 03 d8 92 |YNES...n..'.m...|
00000030 5b 30 83 b4 a4 ea f5 85 be 38 57 |[0.......8W|
[local <]
00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 00 17 68 |...71.0...{....h|
00000010 a6 a2 26 00 66 00 20 07 8b 28 60 a8 08 18 12 47 |..&.f. ..(`....G|
00000020 05 20 3e f5 53 e3 fd 4a cc 03 72 7b b4 2c d9 84 |. >.S..J..r{.,..|
00000030 7f 4b 18 d8 76 7d 5c 65 87 7c 2d |.K..v}\e.|-|
# ---
32 changes: 32 additions & 0 deletions tests/e2e/__snapshots__/test_mqtt_session.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# serializer version: 1
# name: test_session_e2e_publish_message
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt >]
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
00000020 6f 72 64 |ord|
[mqtt >]
00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.|
00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .|
00000020 22 f1 91 1a 6e 89 71 ca ec 2d 44 2a 16 57 e7 5b |"...n.q..-D*.W.[|
00000030 4a 9a c8 97 4b 13 37 3b f5 81 13 45 7c e7 48 03 |J...K.7;...E|.H.|
00000040 99 71 bf |.q.|
# ---
# name: test_session_e2e_receive_message
[mqtt <]
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
[mqtt >]
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
00000020 6f 72 64 |ord|
[mqtt <]
00000000 90 04 00 01 00 00 |......|
[mqtt >]
00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.|
[mqtt <]
00000000 30 31 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |01..topic-1.1.0.|
00000010 00 00 7b 00 00 23 82 68 a6 a2 23 00 66 00 10 45 |..{..#.h..#.f..E|
00000020 3b c3 2b 12 a6 77 d9 55 f6 e0 89 f5 93 a5 30 5d |;.+..w.U......0]|
00000030 a0 72 fa |.r.|
# ---
185 changes: 156 additions & 29 deletions tests/e2e/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,59 @@

import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import patch

import pytest
import syrupy

from roborock.devices.local_channel import LocalChannel
from roborock.protocol import create_local_decoder, create_local_encoder
from roborock.protocol import MessageParser, create_local_decoder
from roborock.protocols.v1_protocol import LocalProtocolVersion
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests.fixtures.logging import CapturedRequestLog
from tests.fixtures.mqtt import Subscriber
from tests.mock_data import LOCAL_KEY

TEST_HOST = "192.168.1.100"
TEST_DEVICE_UID = "test_device_uid"
TEST_CONNECT_NONCE = 12345
TEST_ACK_NONCE = 67890
TEST_RANDOM = 13579
TEST_RANDOM = 23


@pytest.fixture(name="local_channel")
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
with patch(
"roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID
):
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID)
yield channel
channel.close()
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID)
yield channel
channel.close()


def build_response(
def build_raw_response(
protocol: RoborockMessageProtocol,
seq: int,
payload: bytes,
random: int,
version: LocalProtocolVersion = LocalProtocolVersion.V1,
connect_nonce: int | None = None,
ack_nonce: int | None = None,
) -> bytes:
"""Build an encoded response message."""
if protocol == RoborockMessageProtocol.HELLO_RESPONSE:
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None)
else:
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)

msg = RoborockMessage(
message = RoborockMessage(
protocol=protocol,
random=random,
random=23,
seq=seq,
payload=payload,
version=version.value.encode(),
)
return encoder(msg)
return MessageParser.build(message, local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=ack_nonce)


async def test_connect(
local_channel: LocalChannel,
local_response_queue: asyncio.Queue[bytes],
local_received_requests: asyncio.Queue[bytes],
log: CapturedRequestLog,
snapshot: syrupy.SnapshotAssertion,
) -> None:
"""Test connecting to the device."""
# Queue HELLO response with payload to ensure it can be parsed
local_response_queue.put_nowait(
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
)
local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok"))

await local_channel.connect()

Expand All @@ -76,17 +72,19 @@ async def test_connect(
protocol_bytes = request_bytes[19:21]
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST

assert snapshot == log


async def test_send_command(
local_channel: LocalChannel,
local_response_queue: asyncio.Queue[bytes],
local_received_requests: asyncio.Queue[bytes],
log: CapturedRequestLog,
snapshot: syrupy.SnapshotAssertion,
) -> None:
"""Test sending a command."""
# Queue HELLO response
local_response_queue.put_nowait(
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
)
local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok"))

await local_channel.connect()

Expand All @@ -101,16 +99,145 @@ async def test_send_command(
seq=cmd_seq,
payload=b'{"method":"get_status"}',
)
# Prepare a fake response to the command.
local_response_queue.put_nowait(
build_raw_response(RoborockMessageProtocol.RPC_RESPONSE, cmd_seq, payload=b'{"status": "ok"}')
)

subscriber = Subscriber()
unsub = await local_channel.subscribe(subscriber.append)

await local_channel.publish(msg)

# Verify request
# Verify request received by the server
request_bytes = await local_received_requests.get()
assert local_received_requests.empty()

# Decode request
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
decoder = create_local_decoder(local_key=LOCAL_KEY)
msgs = list(decoder(request_bytes))
assert len(msgs) == 1
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
assert msgs[0].payload == b'{"method":"get_status"}'
assert msgs[0].version == LocalProtocolVersion.V1.value.encode()

# Verify response received by subscriber
await subscriber.wait()
assert len(subscriber.messages) == 1
response_message = subscriber.messages[0]
assert isinstance(response_message, RoborockMessage)
assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE
assert response_message.payload == b'{"status": "ok"}'

unsub()

assert snapshot == log


async def test_l01_session(
local_channel: LocalChannel,
local_response_queue: asyncio.Queue[bytes],
local_received_requests: asyncio.Queue[bytes],
log: CapturedRequestLog,
snapshot: syrupy.SnapshotAssertion,
) -> None:
"""Test connecting to a device that speaks the L01 protocol.

Note that this test currently has a delay because the actual local client
will delay before retrying with L01 after a failed 1.0 attempt. This should
also be improved in the actual client itself, but likely requires a closer
look at the actual device response in that scenario or moving to a serial
request/response behavior rather than publish/subscribe.
"""
# Client first attempts 1.0 and we reply with a fake invalid response. The
# response is arbitrary, and this could be improved by capturing a real L01
# device response to a 1.0 message.
local_response_queue.put_nowait(b"\x00")
# The client attempts L01 protocol as a followup. The connect nonce uses
# a deterministic number from deterministic_message_fixtures.
connect_nonce = 9090
local_response_queue.put_nowait(
build_raw_response(
RoborockMessageProtocol.HELLO_RESPONSE,
1,
payload=b"ok",
version=LocalProtocolVersion.L01,
connect_nonce=connect_nonce,
ack_nonce=None,
)
)

await local_channel.connect()

assert local_channel.is_connected

# Verify 1.0 HELLO request
request_bytes = await local_received_requests.get()
# Protocol is at offset 19 (2 bytes)
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
assert len(request_bytes) >= 21
protocol_bytes = request_bytes[19:21]
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST

# Verify L01 HELLO request
request_bytes = await local_received_requests.get()
# Protocol is at offset 19 (2 bytes)
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
assert len(request_bytes) >= 21
protocol_bytes = request_bytes[19:21]
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST

assert local_received_requests.empty()

# Verify the channel switched to L01 protocol
assert local_channel.protocol_version == LocalProtocolVersion.L01.value

# We have established a connection. Now send some messages.
# Publish an L01 command. Currently the caller of the local channel needs to
# determine the protocol version to use, but this could be pushed inside of
# the channel in the future.
cmd_seq = 123
msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
seq=cmd_seq,
payload=b'{"method":"get_status"}',
version=b"L01",
)
# Prepare a fake response to the command.
local_response_queue.put_nowait(
build_raw_response(
RoborockMessageProtocol.RPC_RESPONSE,
cmd_seq,
payload=b'{"status": "ok"}',
version=LocalProtocolVersion.L01,
connect_nonce=connect_nonce,
ack_nonce=TEST_RANDOM,
)
)

# Set up a subscriber to listen for the response then publish the message.
subscriber = Subscriber()
unsub = await local_channel.subscribe(subscriber.append)
await local_channel.publish(msg)

# Verify request received by the server
request_bytes = await local_received_requests.get()
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=TEST_RANDOM)
msgs = list(decoder(request_bytes))
assert len(msgs) == 1
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
assert msgs[0].payload == b'{"method":"get_status"}'
assert msgs[0].version == LocalProtocolVersion.L01.value.encode()

# Verify fake response published by the server, received by subscriber
await subscriber.wait()
assert len(subscriber.messages) == 1
response_message = subscriber.messages[0]
assert isinstance(response_message, RoborockMessage)
assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE
assert response_message.payload == b'{"status": "ok"}'
assert response_message.version == LocalProtocolVersion.L01.value.encode()

unsub()

assert snapshot == log
Loading