Skip to content

Commit 0e30bfa

Browse files
perf: make robust to network access patterns by caching header
1 parent 920f2d2 commit 0e30bfa

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

automated_test.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from mapbuffer import MapBuffer, HEADER_LENGTH
44
import random
5+
import mmap
56

67
@pytest.mark.parametrize("compress", (None, "gzip", "br", "zstd", "lzma"))
78
def test_empty(compress):
@@ -49,4 +50,46 @@ def test_full(compress):
4950

5051
mbuf.validate()
5152

52-
assert len(mbuf.buffer) > HEADER_LENGTH
53+
assert len(mbuf.buffer) > HEADER_LENGTH
54+
55+
@pytest.mark.parametrize("compress", (None, "gzip", "br", "zstd"))
56+
def test_mmap_access(compress):
57+
data = {
58+
1: b"hello",
59+
2: b"world",
60+
}
61+
mbuf = MapBuffer(data, compress=compress)
62+
with open("test_mmap.mb", "wb") as f:
63+
f.write(mbuf.tobytes())
64+
65+
with open("test_mmap.mb", "rb") as f:
66+
mb = MapBuffer(f)
67+
68+
assert mb[1] == b"hello"
69+
assert mb[2] == b"world"
70+
71+
@pytest.mark.parametrize("compress", (None, "gzip", "br", "zstd"))
72+
def test_object_access(compress):
73+
data = {
74+
1: b"hello",
75+
2: b"world",
76+
}
77+
mbuf = MapBuffer(data, compress=compress)
78+
79+
class Reader:
80+
def __init__(self):
81+
self.lst = mbuf.tobytes()
82+
def __getitem__(self, slc):
83+
return self.lst[slc]
84+
85+
mbuf2 = MapBuffer(Reader())
86+
assert mbuf2[1] == b"hello"
87+
assert mbuf2[2] == b"world"
88+
89+
90+
91+
92+
93+
94+
95+

mapbuffer/mapbuffer.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class MapBuffer:
1717
"""Represents a usable int->bytes dictionary as a byte string."""
1818
__slots__ = (
1919
"data", "tobytesfn", "frombytesfn",
20-
"dtype", "buffer", "_index", "_compress"
20+
"dtype", "buffer", "_header",
21+
"_index", "_compress"
2122
)
2223
def __init__(
2324
self, data=None, compress=None,
@@ -41,6 +42,7 @@ def __init__(
4142
self.dtype = np.uint64
4243
self.buffer = None
4344

45+
self._header = None
4446
self._index = None
4547
self._compress = None
4648

@@ -50,26 +52,31 @@ def __init__(
5052
self.buffer = mmap.mmap(data.fileno(), 0, prot=mmap.PROT_READ)
5153
elif isinstance(data, (bytes, mmap.mmap)):
5254
self.buffer = data
55+
elif hasattr(data, "__getitem__"):
56+
self.buffer = data
5357
else:
54-
raise TypeError("data must be a dict, bytes, file, or mmap. Got: " + str(type(data)))
58+
raise TypeError(
59+
f"data must be a dict, bytes, file, mmap, or otherwise support "
60+
f"__getitem__ with slice support for byte ranges. Got: {type(data)}"
61+
)
5562

5663
def __len__(self):
5764
"""Returns number of keys."""
58-
return int.from_bytes(self.buffer[12:16], byteorder="little", signed=False)
65+
return int.from_bytes(self.header[12:16], byteorder="little", signed=False)
5966

6067
@property
6168
def compress(self):
6269
if self._compress:
6370
return self._compress
6471

6572
self._compress = compression.normalize_encoding(
66-
self.buffer[8:12]
73+
self.header[8:12]
6774
)
6875
return self._compress
6976

7077
@property
7178
def format_version(self):
72-
return self.buffer[len(MAGIC_NUMBERS)]
79+
return self.header[len(MAGIC_NUMBERS)]
7380

7481
def __iter__(self):
7582
yield from self.keys()
@@ -78,6 +85,17 @@ def datasize(self):
7885
"""Returns size of data region in bytes."""
7986
return len(self.buffer) - HEADER_LENGTH - len(self) * 2 * 8
8087

88+
@property
89+
def header(self):
90+
"""Get the header bytes."""
91+
if self._header is not None:
92+
return self._header
93+
94+
# seems dumb, but if self.buffer is an object that
95+
# requires network access, this is a valuable cache
96+
self._header = self.buffer[:HEADER_LENGTH]
97+
return self._header
98+
8199
def index(self):
82100
"""Get an Nx2 numpy array representing the index."""
83101
if self._index is not None:
@@ -217,11 +235,12 @@ def validate(self):
217235
@staticmethod
218236
def validate_buffer(buf):
219237
mapbuf = MapBuffer(buf)
238+
header = mapbuf.header
220239
index = mapbuf.index()
221240
if len(index) != len(mapbuf):
222241
raise ValidationError(f"Index size doesn't match. len(mapbuf): {len(mapbuf)}")
223242

224-
magic = buf[:len(MAGIC_NUMBERS)]
243+
magic = header[:len(MAGIC_NUMBERS)]
225244
if magic != MAGIC_NUMBERS:
226245
raise ValidationError(f"Magic number mismatch. Expected: {MAGIC_NUMBERS} Got: {magic}")
227246

0 commit comments

Comments
 (0)