Skip to content

Commit 81ff9ed

Browse files
giyokunvrslev
andauthored
Improve binary message handling (#145)
Co-authored-by: Lev Vereshchagin <mail@vrslev.com>
1 parent 3f689f0 commit 81ff9ed

File tree

2 files changed

+90
-41
lines changed

2 files changed

+90
-41
lines changed

packages/stompman/stompman/serde.py

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import struct
2-
from collections import deque
32
from collections.abc import Iterator
43
from contextlib import suppress
5-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass
65
from typing import Any, Final, cast
76

87
from stompman.frames import (
@@ -141,53 +140,83 @@ def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: byte
141140
return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg]
142141

143142

144-
def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:
145-
command = bytes(lines.popleft())
146-
headers = {}
147-
148-
while line := lines.popleft():
149-
header = parse_header(line)
150-
if header and header[0] not in headers:
151-
headers[header[0]] = header[1]
152-
body = bytes(lines.popleft()) if lines else b""
153-
return make_frame_from_parts(command=command, headers=headers, body=body)
154-
155-
156-
@dataclass(kw_only=True, slots=True)
143+
@dataclass(kw_only=True, slots=True, init=False)
157144
class FrameParser:
158-
_lines: deque[bytearray] = field(default_factory=deque, init=False)
159-
_current_line: bytearray = field(default_factory=bytearray, init=False)
160-
_previous_byte: bytes = field(default=b"", init=False)
161-
_headers_processed: bool = field(default=False, init=False)
145+
_current_buf: bytearray
146+
_previous_byte: bytes | None
147+
_headers_processed: bool
148+
_command: bytes | None
149+
_headers: dict[str, str]
150+
_content_length: int | None
151+
152+
def __init__(self) -> None:
153+
self._previous_byte = None
154+
self._reset()
162155

163156
def _reset(self) -> None:
157+
self._current_buf = bytearray()
164158
self._headers_processed = False
165-
self._lines.clear()
166-
self._current_line = bytearray()
159+
self._command = None
160+
self._headers = {}
161+
self._content_length = None
162+
163+
def _handle_null_byte(self) -> Iterator[AnyClientFrame | AnyServerFrame]:
164+
if not self._command or not self._headers_processed:
165+
self._reset()
166+
return
167+
if self._content_length is not None and self._content_length != len(self._current_buf):
168+
self._current_buf += NULL
169+
return
170+
yield make_frame_from_parts(command=self._command, headers=self._headers, body=bytes(self._current_buf))
171+
self._reset()
172+
173+
def _handle_newline_byte(self) -> Iterator[HeartbeatFrame]:
174+
if not self._current_buf and not self._command:
175+
yield HeartbeatFrame()
176+
return
177+
if self._previous_byte == CARRIAGE:
178+
self._current_buf.pop()
179+
self._headers_processed = not self._current_buf # extra empty line after headers
180+
181+
if self._command:
182+
self._process_header()
183+
else:
184+
self._process_command()
185+
186+
def _process_command(self) -> None:
187+
current_buf_bytes = bytes(self._current_buf)
188+
if current_buf_bytes not in COMMANDS_TO_FRAMES:
189+
self._reset()
190+
else:
191+
self._command = current_buf_bytes
192+
self._current_buf = bytearray()
193+
194+
def _process_header(self) -> None:
195+
header = parse_header(self._current_buf)
196+
if not header:
197+
self._current_buf = bytearray()
198+
return
199+
header_key, header_value = header
200+
if header_key not in self._headers:
201+
self._headers[header_key] = header_value
202+
if header_key.lower() == "content-length":
203+
with suppress(ValueError):
204+
self._content_length = int(header_value)
205+
self._current_buf = bytearray()
206+
207+
def _handle_body_byte(self, byte: bytes) -> None:
208+
if self._content_length is None or self._content_length != len(self._current_buf):
209+
self._current_buf += byte
167210

168211
def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame]:
169212
for byte in iter_bytes(chunk):
170213
if byte == NULL:
171-
if self._headers_processed:
172-
self._lines.append(self._current_line)
173-
yield parse_lines_into_frame(self._lines)
174-
self._reset()
175-
176-
elif not self._headers_processed and byte == NEWLINE:
177-
if self._current_line or self._lines:
178-
if self._previous_byte == CARRIAGE:
179-
self._current_line.pop()
180-
self._headers_processed = not self._current_line # extra empty line after headers
181-
182-
if not self._lines and bytes(self._current_line) not in COMMANDS_TO_FRAMES:
183-
self._reset()
184-
else:
185-
self._lines.append(self._current_line)
186-
self._current_line = bytearray()
187-
else:
188-
yield HeartbeatFrame()
189-
214+
yield from self._handle_null_byte()
215+
elif self._headers_processed:
216+
self._handle_body_byte(byte)
217+
elif byte == NEWLINE:
218+
yield from self._handle_newline_byte()
190219
else:
191-
self._current_line += byte
220+
self._current_buf += byte
192221

193222
self._previous_byte = byte

packages/stompman/test_stompman/test_frame_serde.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None:
222222
ConnectedFrame(headers={"header": "1.2"}),
223223
],
224224
),
225+
# Correct content-length with body containing NULL byte
226+
(
227+
b"MESSAGE\ncontent-length:5\n\nBod\x00y\x00",
228+
[MessageFrame(headers={"content-length": "5"}, body=b"Bod\x00y")],
229+
),
230+
# Content-length shorter than actual body (should only read up to content-length)
231+
(
232+
b"MESSAGE\ncontent-length:4\n\nBody\x00 with extra\x00\n",
233+
[MessageFrame(headers={"content-length": "4"}, body=b"Body"), HeartbeatFrame()],
234+
),
235+
# Content-length longer than actual body (should wait for more data)
236+
(
237+
b"MESSAGE\ncontent-length:10\n\nShort",
238+
[],
239+
),
240+
# Content-length longer than actual body, then more data comes with NULL terminator
241+
(
242+
b"MESSAGE\ncontent-length:10\n\nShortMOREDATA\x00",
243+
[MessageFrame(headers={"content-length": "10"}, body=b"ShortMORED")],
244+
),
225245
],
226246
)
227247
def test_load_frames(raw_frames: bytes, loaded_frames: list[AnyServerFrame]) -> None:

0 commit comments

Comments
 (0)