|
1 | 1 | import struct |
2 | | -from collections import deque |
3 | 2 | from collections.abc import Iterator |
4 | 3 | from contextlib import suppress |
5 | | -from dataclasses import dataclass, field |
| 4 | +from dataclasses import dataclass |
6 | 5 | from typing import Any, Final, cast |
7 | 6 |
|
8 | 7 | from stompman.frames import ( |
@@ -141,53 +140,83 @@ def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: byte |
141 | 140 | return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg] |
142 | 141 |
|
143 | 142 |
|
144 | | -def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame: |
145 | | - command = bytes(lines.popleft()) |
146 | | - headers = {} |
147 | | - |
148 | | - while line := lines.popleft(): |
149 | | - header = parse_header(line) |
150 | | - if header and header[0] not in headers: |
151 | | - headers[header[0]] = header[1] |
152 | | - body = bytes(lines.popleft()) if lines else b"" |
153 | | - return make_frame_from_parts(command=command, headers=headers, body=body) |
154 | | - |
155 | | - |
156 | | -@dataclass(kw_only=True, slots=True) |
| 143 | +@dataclass(kw_only=True, slots=True, init=False) |
157 | 144 | class FrameParser: |
158 | | - _lines: deque[bytearray] = field(default_factory=deque, init=False) |
159 | | - _current_line: bytearray = field(default_factory=bytearray, init=False) |
160 | | - _previous_byte: bytes = field(default=b"", init=False) |
161 | | - _headers_processed: bool = field(default=False, init=False) |
| 145 | + _current_buf: bytearray |
| 146 | + _previous_byte: bytes | None |
| 147 | + _headers_processed: bool |
| 148 | + _command: bytes | None |
| 149 | + _headers: dict[str, str] |
| 150 | + _content_length: int | None |
| 151 | + |
| 152 | + def __init__(self) -> None: |
| 153 | + self._previous_byte = None |
| 154 | + self._reset() |
162 | 155 |
|
163 | 156 | def _reset(self) -> None: |
| 157 | + self._current_buf = bytearray() |
164 | 158 | self._headers_processed = False |
165 | | - self._lines.clear() |
166 | | - self._current_line = bytearray() |
| 159 | + self._command = None |
| 160 | + self._headers = {} |
| 161 | + self._content_length = None |
| 162 | + |
| 163 | + def _handle_null_byte(self) -> Iterator[AnyClientFrame | AnyServerFrame]: |
| 164 | + if not self._command or not self._headers_processed: |
| 165 | + self._reset() |
| 166 | + return |
| 167 | + if self._content_length is not None and self._content_length != len(self._current_buf): |
| 168 | + self._current_buf += NULL |
| 169 | + return |
| 170 | + yield make_frame_from_parts(command=self._command, headers=self._headers, body=bytes(self._current_buf)) |
| 171 | + self._reset() |
| 172 | + |
| 173 | + def _handle_newline_byte(self) -> Iterator[HeartbeatFrame]: |
| 174 | + if not self._current_buf and not self._command: |
| 175 | + yield HeartbeatFrame() |
| 176 | + return |
| 177 | + if self._previous_byte == CARRIAGE: |
| 178 | + self._current_buf.pop() |
| 179 | + self._headers_processed = not self._current_buf # extra empty line after headers |
| 180 | + |
| 181 | + if self._command: |
| 182 | + self._process_header() |
| 183 | + else: |
| 184 | + self._process_command() |
| 185 | + |
| 186 | + def _process_command(self) -> None: |
| 187 | + current_buf_bytes = bytes(self._current_buf) |
| 188 | + if current_buf_bytes not in COMMANDS_TO_FRAMES: |
| 189 | + self._reset() |
| 190 | + else: |
| 191 | + self._command = current_buf_bytes |
| 192 | + self._current_buf = bytearray() |
| 193 | + |
| 194 | + def _process_header(self) -> None: |
| 195 | + header = parse_header(self._current_buf) |
| 196 | + if not header: |
| 197 | + self._current_buf = bytearray() |
| 198 | + return |
| 199 | + header_key, header_value = header |
| 200 | + if header_key not in self._headers: |
| 201 | + self._headers[header_key] = header_value |
| 202 | + if header_key.lower() == "content-length": |
| 203 | + with suppress(ValueError): |
| 204 | + self._content_length = int(header_value) |
| 205 | + self._current_buf = bytearray() |
| 206 | + |
| 207 | + def _handle_body_byte(self, byte: bytes) -> None: |
| 208 | + if self._content_length is None or self._content_length != len(self._current_buf): |
| 209 | + self._current_buf += byte |
167 | 210 |
|
168 | 211 | def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame]: |
169 | 212 | for byte in iter_bytes(chunk): |
170 | 213 | if byte == NULL: |
171 | | - if self._headers_processed: |
172 | | - self._lines.append(self._current_line) |
173 | | - yield parse_lines_into_frame(self._lines) |
174 | | - self._reset() |
175 | | - |
176 | | - elif not self._headers_processed and byte == NEWLINE: |
177 | | - if self._current_line or self._lines: |
178 | | - if self._previous_byte == CARRIAGE: |
179 | | - self._current_line.pop() |
180 | | - self._headers_processed = not self._current_line # extra empty line after headers |
181 | | - |
182 | | - if not self._lines and bytes(self._current_line) not in COMMANDS_TO_FRAMES: |
183 | | - self._reset() |
184 | | - else: |
185 | | - self._lines.append(self._current_line) |
186 | | - self._current_line = bytearray() |
187 | | - else: |
188 | | - yield HeartbeatFrame() |
189 | | - |
| 214 | + yield from self._handle_null_byte() |
| 215 | + elif self._headers_processed: |
| 216 | + self._handle_body_byte(byte) |
| 217 | + elif byte == NEWLINE: |
| 218 | + yield from self._handle_newline_byte() |
190 | 219 | else: |
191 | | - self._current_line += byte |
| 220 | + self._current_buf += byte |
192 | 221 |
|
193 | 222 | self._previous_byte = byte |
0 commit comments