
Commit 125d90f

fix: support sentinel based write_in_worker in mp backend

1 parent d85b5c1

File tree

11 files changed: +236 −106 lines

changelog.md

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@
 - Readers now have a `loop` parameter to cycle over the data indefinitely (useful for training)
 - Readers now have a `shuffle` parameter to shuffle the data before iterating over it
 - In `multiprocessing` mode, file-based readers now read the data in the workers (this was previously an option)
+- We now support two new special batch sizes:
+    - "fragment" in the case of parquet datasets: the rows of one full parquet file fragment per batch
+    - "dataset", which is mostly useful during training, for instance to shuffle the dataset at each epoch
+  These are also compatible with batched writers such as parquet, where each input fragment can be processed and mapped to a single matching output fragment.
 - :boom: Breaking change: a `map` function returning a list or a generator won't be automatically flattened anymore. Use `flatten()` to flatten the output if needed. This shouldn't change the behavior for most users, since most writers (to_pandas, to_polars, to_parquet, ...) still flatten the output
 - :boom: Breaking change: `chunk_size` and `sort_chunks` are now deprecated: to sort data before applying a transformation, use `.map_batches(custom_sort_fn, batch_size=...)`
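To make the new special batch sizes concrete, here is a minimal usage sketch. It assumes the `read_parquet`/`write_parquet` signatures shown further down in this commit; the paths and the "omop" converter are placeholders, not part of the change:

import edsnlp

# Hedged sketch: paths and converter are placeholders.
stream = edsnlp.data.read_parquet("input/notes/", converter="omop")

# batch_size="fragment": each batch holds the rows of one input parquet
# fragment, so each input fragment can be mapped to one matching output
# fragment by the writer.
edsnlp.data.write_parquet(
    stream,
    "output/notes/",
    batch_size="fragment",
    overwrite=True,
)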

edsnlp/core/stream.py

Lines changed: 25 additions & 18 deletions
@@ -88,10 +88,6 @@ class Op(abc.ABC):
     def __call__(self, items):
         raise NotImplementedError()

-    @property
-    def expected_sentinels(self):
-        return set()
-

 class FlattenOp(Op):
     elementwise = False
@@ -150,10 +146,6 @@ def __repr__(self):
             f"sentinel_mode={self.sentinel_mode})"
         )

-    @property
-    def expected_sentinels(self):
-        return getattr(self.batch_fn, "expected_sentinels", set())
-

 class MapOp(Op):
     def __init__(self, pipe, kwargs):
@@ -966,7 +958,13 @@ def _make_stages(self, split_torch_pipes: bool) -> List[Stage]:

     def validate_ops(self, ops, update: bool = False):
         # Check batchify requirements
-        expected_sentinels = set()
+        requires_sentinels = set()
+
+        if hasattr(self.writer, "batch_fn") and hasattr(
+            self.writer.batch_fn, "requires_sentinel"
+        ):
+            requires_sentinels.add(self.writer.batch_fn.requires_sentinel)
+
         self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
         for op in reversed(ops):
             if isinstance(op, BatchifyOp):
@@ -977,29 +975,38 @@
                     else None
                 )
                 if sentinel_mode == "auto":
-                    sentinel_mode = "split" if expected_sentinels else "drop"
-                if expected_sentinels and op.sentinel_mode == "drop":
+                    sentinel_mode = "split" if requires_sentinels else "drop"
+                if requires_sentinels and op.sentinel_mode == "drop":
                     raise ValueError(
                         f"Operation {op} drops the stream sentinel values "
                         f"(markers for the end of a dataset or a dataset "
                         f"fragment), but some downstream operation(s) require "
-                        f"the following sentinel values: {expected_sentinels}. "
+                        f"the following sentinel values: {requires_sentinels}. "
                         f"Ensure that you do not set `sentinel_mode='drop'` on "
                         f"any upstream batching operation."
                     )
-                expected_sentinels.update(op.expected_sentinels)
                 if update:
                     op.sentinel_mode = sentinel_mode

-        if expected_sentinels and (self.backend == "spark" or not self.deterministic):
+                if hasattr(batch_fn, "requires_sentinel"):
+                    requires_sentinels.add(batch_fn.requires_sentinel)
+
+        sentinel_str = ", ".join(requires_sentinels)
+        if requires_sentinels and self.backend == "spark":
             raise ValueError(
-                f"Some operations require sentinel values ({expected_sentinels}), "
+                f"Some operations require sentinel values ({sentinel_str}), "
                 f"but the Spark backend does not support sentinel values."
             )
-        if not (expected_sentinels < self.reader.emitted_sentinels):
+        if requires_sentinels and not self.deterministic:
+            raise ValueError(
+                f"Some operations require sentinel values ({sentinel_str}), "
+                f"but these are not supported when `deterministic=False`."
+            )
+        if not (requires_sentinels <= self.reader.emitted_sentinels):
             raise ValueError(
-                f"Some operations require sentinel values ({expected_sentinels}), "
-                f"but the reader does not emit these values."
+                f"Some operations require sentinel values ({sentinel_str}), "
+                f"but the reader does not emit these values "
+                f"({', '.join(self.reader.emitted_sentinels)})."
             )

     def __repr__(self):
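For context on the `requires_sentinel` attribute that `validate_ops` now collects (from the writer's `batch_fn` and from each `BatchifyOp`'s batching function) and checks against the reader's `emitted_sentinels`: a batching function declares the kind of sentinel it needs to see in the stream. Below is a hypothetical sketch of such a function, not taken from this commit; the real implementations live in `edsnlp.utils.batching` and their exact signatures may differ:

from edsnlp.utils.stream_sentinels import FragmentEndSentinel

def batchify_by_fragment(iterable, batch_size=None, sentinel_mode="split"):
    # Close the current batch at each fragment-end marker.
    batch = []
    for item in iterable:
        if isinstance(item, FragmentEndSentinel):
            if batch:
                yield batch
                batch = []
            if sentinel_mode == "split":
                yield item  # forward the sentinel so downstream ops stay aligned
            # with sentinel_mode="drop", the sentinel is only a batch boundary
        else:
            batch.append(item)
    if batch:
        yield batch

# The attribute inspected by validate_ops and by the executors: the kind of
# sentinel this batching function needs the reader to emit.
batchify_by_fragment.requires_sentinel = "fragment"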

edsnlp/data/base.py

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ def consolidate(self, items: Iterable):

 class BatchWriter(BaseWriter):
     batch_size: Optional[int] = None
-    batch_by: Callable
-    batch_in_worker: bool = False
+    batch_fn: Callable
+    write_in_worker: bool = False

     def handle_batch(self, batch):
         raise NotImplementedError()
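The renamed interface reads as follows: `batch_fn` (formerly `batch_by`) groups records into batches, `write_in_worker` (formerly `batch_in_worker`) decides whether batches are formed and written in the workers or in the main process, and `handle_batch` returns a `(result, count)` pair, as the `writer.handle_batch(b)[0]` calls in the executors below imply. A toy subclass for illustration only (hypothetical, not part of this commit):

from edsnlp.data.base import BatchWriter
from edsnlp.utils.collections import batchify

class CountingWriter(BatchWriter):
    # Toy writer that only counts records per batch.

    def __init__(self, batch_size=1024, write_in_worker=False):
        self.batch_size = batch_size
        self.batch_fn = batchify  # plain size-based batching
        self.write_in_worker = write_in_worker

    def handle_record(self, record):
        return record  # no per-record conversion in this toy example

    def handle_batch(self, batch):
        # Return the processed batch and the number of records it held,
        # matching the (result, count) contract used by the executors.
        return batch, len(batch)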

edsnlp/data/parquet.py

Lines changed: 34 additions & 35 deletions
@@ -14,7 +14,7 @@
 from edsnlp.core.stream import Stream
 from edsnlp.data.base import BatchWriter, FileBasedReader
 from edsnlp.data.converters import get_dict2doc_converter, get_doc2dict_converter
-from edsnlp.utils.batching import batchify_fns
+from edsnlp.utils.batching import BatchBy, batchify_fns
 from edsnlp.utils.collections import batchify, dl_to_ld, flatten, ld_to_dl, shuffle
 from edsnlp.utils.file_system import FileSystem, normalize_fs_path
 from edsnlp.utils.stream_sentinels import DatasetEndSentinel, FragmentEndSentinel
@@ -61,12 +61,16 @@ def read_records(self) -> Iterable[Any]:
                 for file in shuffle(files, self.rng):
                     records = shuffle(self.read_fragment(file), self.rng)
                     yield from records
-                    yield FragmentEndSentinel(file)
-            else:
+                    yield FragmentEndSentinel(file.path)
+            elif self.shuffle == "dataset":
                 records = (line for file in files for line in self.read_fragment(file))
-                if self.shuffle == "dataset":
-                    records = shuffle(records, self.rng)
+                records = shuffle(records, self.rng)
                 yield from records
+            else:
+                for file in files:
+                    records = list(self.read_fragment(file))
+                    yield from records
+                    yield FragmentEndSentinel(file.path)
             yield DatasetEndSentinel()
             if not self.loop:
                 break
@@ -85,9 +89,9 @@ def __init__(
         self,
         *,
         path: Union[str, Path],
-        batch_size: Optional[Union[int]] = None,
-        batch_by: Union[Callable, Literal["docs"]] = "docs",
-        batch_in_worker: bool = False,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
+        write_in_worker: bool = False,
         overwrite: bool,
         filesystem: Optional[FileSystem] = None,
     ):
@@ -113,21 +117,18 @@
             for file in dataset.files:
                 self.fs.rm_file(file)
         self.fs = filesystem
-        assert batch_by is None or batch_by in batchify_fns or callable(batch_by)
-        self.batch_by = batchify_fns.get(batch_by, batch_by)
-        if (
-            batch_by in ("docs", "doc")
-            or self.batch_by is batchify
-            and batch_size is None
-        ):
+        batch_size, batch_by = Stream.validate_batching(batch_size, batch_by)
+        if batch_by in ("docs", "doc", None, batchify) and batch_size is None:
             warnings.warn(
                 "You should specify a batch size when using record-wise batch writing. "
                 "Setting batch size to 1024."
             )
             batch_size = 1024
+        batch_by = batch_by or "docs"
+        self.batch_fn = batchify_fns.get(batch_by, batch_by)

         self.batch_size = batch_size
-        self.batch_in_worker = batch_in_worker
+        self.write_in_worker = write_in_worker
         self.batch = []
         self.closed = False

@@ -250,9 +251,9 @@ def write_parquet(
     data: Union[Any, Stream],
     path: Union[str, Path],
     *,
-    batch_size: Optional[int] = None,
-    batch_by: Union[Callable, Literal["docs"]] = "docs",
-    batch_in_worker: bool = True,
+    batch_size: Optional[Union[int, str]] = None,
+    batch_by: BatchBy = None,
+    write_in_worker: bool = True,
     overwrite: bool = False,
     filesystem: Optional[FileSystem] = None,
     execute: bool = True,
@@ -295,15 +296,17 @@
         The method to batch the documents. If "docs", the batch size is the number of
         documents. If "fragment", each batch corresponds to a parquet file fragment from
         the input data.
-    batch_in_worker: bool
-        In multiprocessing or spark mode, whether to batch the documents in the workers
-        or in the main process.
+    write_in_worker: bool
+        In multiprocessing or spark mode, whether to batch and write the documents in
+        the workers or in the main process.

         For instance, a worker may read the 1st, 3rd, 5th, ... documents, while another
-        reads the 2nd, 4th, 6th, ... documents. If `batch_in_worker` is False and
-        `deterministic` is True (default), the original order of the documents will be
-        recovered in the main process, and batching there can produce fragments that
-        respect the original order.
+        reads the 2nd, 4th, 6th, ... documents.
+
+        If `write_in_worker` is False, `deterministic` is True (default), and no
+        operation adds or removes documents from the stream (e.g., no `map_batches`),
+        the original order of the documents will be recovered in the main process,
+        and batching there can produce fragments that respect the original order.
     overwrite: bool
         Whether to overwrite existing directories.
     filesystem: Optional[AbstractFileSystem] = None,
@@ -326,14 +329,10 @@ def write_parquet(
             batch_size is None
         ), "Cannot specify both 'batch_size' and deprecated 'num_rows_per_file'."
         batch_size = kwargs.pop("num_rows_per_file")
-        assert batch_by == "docs", "Cannot use 'num_rows_per_file' with 'batch_by'."
-    if "write_in_worker" in kwargs:
-        warnings.warn(
-            "The 'write_in_worker' parameter is deprecated. To perform "
-            "batching in the worker processes, set 'batch_in_worker=True'.",
-            VisibleDeprecationWarning,
-        )
-        batch_in_worker = kwargs.pop("write_in_worker")
+        assert batch_by in (
+            None,
+            "docs",
+        ), "Cannot use 'num_rows_per_file' with 'batch_by'."
     if "accumulate" in kwargs:
         warnings.warn(
             "The 'accumulate' parameter is deprecated.", VisibleDeprecationWarning
@@ -347,7 +346,7 @@
             path=path,
             batch_size=batch_size,
             batch_by=batch_by,
-            batch_in_worker=batch_in_worker,
+            write_in_worker=write_in_worker,
             overwrite=overwrite,
             filesystem=filesystem,
         ),
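To illustrate the trade-off documented in the docstring above, a hedged usage sketch (paths are placeholders; `set_processing` is the usual way to select a backend in edsnlp):

import edsnlp

stream = edsnlp.data.read_parquet("input/notes/", converter="omop")
stream = stream.set_processing(backend="multiprocessing", num_cpu_workers=4)

# write_in_worker=False: batches are formed in the main process after the
# original document order has been restored (deterministic=True by default),
# so each output fragment can mirror one input fragment.
edsnlp.data.write_parquet(
    stream,
    "output/notes/",
    batch_size="fragment",
    write_in_worker=False,
    overwrite=True,
)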

edsnlp/processing/multiprocessing.py

Lines changed: 43 additions & 10 deletions
@@ -705,16 +705,27 @@ def preprocess_before_forward(self, items, stage):
     def send_results(self, items):
         writer = self.stream.writer
         if writer is not None:
-            items = (writer.handle_record(rec) for rec in items)
-            if getattr(writer, "batch_in_worker", None) is True:
-                items = writer.batch_by(
+            items = (
+                writer.handle_record(rec)
+                if not isinstance(rec, StreamSentinel)
+                else rec
+                for rec in items
+            )
+            if getattr(writer, "write_in_worker", None) is True:
+                items = writer.batch_fn(
                     items,
                     batch_size=writer.batch_size,
                     sentinel_mode="drop",
                 )
-                items = (writer.handle_batch(b) for b in items)
+                items = (
+                    writer.handle_batch(b)
+                    for b in items
+                    if not isinstance(b, StreamSentinel)
+                )
             else:
-                items = ((x, 1) for x in items if not isinstance(x, StreamSentinel))
+                items = (
+                    (x, 1) if not isinstance(x, StreamSentinel) else (x, 0) for x in items
+                )

         name = f"from-{self.uid}_to-main"
         queue = self.data_queues[name]
@@ -1024,11 +1035,15 @@ def run(self):
         # Create the main iterator
         items = self.dequeue_outputs()
         writer = self.stream.writer
-        if getattr(writer, "batch_in_worker", None) is False:
+        if getattr(writer, "write_in_worker", None) is False:
             writer: BatchWriter
-            items = writer.batch_by(items, writer.batch_size)
+            items = writer.batch_fn(items, writer.batch_size, sentinel_mode="drop")
             # get the 1st element (2nd is the count)
-            items = (writer.handle_batch(b)[0] for b in items)
+            items = (
+                writer.handle_batch(b)[0]
+                for b in items
+                if not isinstance(b, StreamSentinel)
+            )

         # If we are garbage collected, stop the execution
         weakref.finalize(items, self.teardown, garbage_collected=True)
@@ -1060,6 +1075,13 @@

     def iter_outputs(self, stop_mode=False):
         deterministic = self.stream.deterministic
+        requires_sentinel = (
+            hasattr(self.stream.writer, "batch_fn")
+            and getattr(self.stream.writer.batch_fn, "requires_sentinel", None)
+            and not self.stream.writer.write_in_worker
+        )
+        missing_sentinels = len(self.cpu_worker_names) if requires_sentinel else 0
+        buffer = []
         while self.num_alive_workers > 0:
             if self.stopped and not stop_mode:  # pragma: no cover
                 raise StopSignal()
@@ -1097,9 +1119,20 @@
                 self.num_alive_workers -= 1
                 self.workers_status[worker_idx] = False
                 continue
-            if isinstance(out, StreamSentinel) and worker_idx > 0:
+            if isinstance(out[0], StreamSentinel):
+                if out[0].kind == requires_sentinel:
+                    missing_sentinels -= 1
+                    if missing_sentinels == 0:
+                        yield from buffer
+                        yield out
+                        buffer.clear()
+                        missing_sentinels = len(self.cpu_worker_names)
                 continue
-            yield out
+            if requires_sentinel:
+                buffer.append(out)
+            else:
+                yield out
+        yield from buffer
         if self.error:
             raise self.error
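The buffering added to `iter_outputs` holds worker outputs back until every CPU worker has emitted the sentinel kind required by the writer's `batch_fn`, so a fragment is only flushed once it is complete across all workers. A standalone toy sketch of the idea (the real code polls per-worker queues and carries `(item, count)` tuples; `StreamSentinel` is assumed importable from `edsnlp.utils.stream_sentinels`):

from edsnlp.utils.stream_sentinels import StreamSentinel

def round_robin(iterators):
    # Simplified stand-in for polling the worker output queues in turn.
    iterators = [iter(it) for it in iterators]
    while iterators:
        for it in list(iterators):
            try:
                yield next(it)
            except StopIteration:
                iterators.remove(it)

def merge_outputs(worker_outputs, n_workers, kind="fragment"):
    missing, buffer = n_workers, []
    for out in round_robin(worker_outputs):
        if isinstance(out, StreamSentinel):
            if out.kind == kind:
                missing -= 1
                if missing == 0:  # every worker finished this fragment
                    yield from buffer
                    yield out
                    buffer.clear()
                    missing = n_workers
            continue
        buffer.append(out)
    yield from buffer  # flush whatever remains once all workers stop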

edsnlp/processing/simple.py

Lines changed: 5 additions & 4 deletions
@@ -80,12 +80,13 @@ def process():
                 for item in items
             )

-            if getattr(writer, "batch_by", None) is not None:
-                items = writer.batch_by(items, writer.batch_size, sentinel_mode="drop")
+            if getattr(writer, "batch_fn", None) is not None:
+                items = writer.batch_fn(items, writer.batch_size, sentinel_mode="drop")
                 # get the 1st element (2nd is the count)
                 for b in items:
-                    item, count = writer.handle_batch(b)
-                    bar.update(count)
+                    if not isinstance(b, StreamSentinel):
+                        item, count = writer.handle_batch(b)
+                        bar.update(count)
                     yield item
             else:
                 for item in items:

edsnlp/processing/spark.py

Lines changed: 4 additions & 4 deletions
@@ -132,8 +132,8 @@ def process_partition(items):  # pragma: no cover
         items = (writer.handle_record(item) for item in items)

         results = []
-        if getattr(writer, "batch_in_worker", None) is True:
-            items = writer.batch_by(items, writer.batch_size)
+        if getattr(writer, "write_in_worker", None) is True:
+            items = writer.batch_fn(items, writer.batch_size)
             # get the 1st element (2nd is the count)
             for item in items:
                 item, count = writer.handle_batch(item)
@@ -163,9 +163,9 @@
         for item in df.rdd.mapPartitions(process_partition).toLocalIterator()
     )

-    if getattr(writer, "batch_in_worker", None) is False:
+    if getattr(writer, "write_in_worker", None) is False:
         writer: BatchWriter
-        items = writer.batch_by(items, writer.batch_size)
+        items = writer.batch_fn(items, writer.batch_size)
         # get the 1st element (2nd is the count)
         items = (writer.handle_batch(b)[0] for b in items)
     return items if writer is None else writer.consolidate(items)
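For completeness, on Spark the writer batches inside each partition when `write_in_worker=True` (the default for `write_parquet`), and `validate_ops` rejects sentinel-requiring batch functions on this backend, so `batch_size` should remain an integer there. A hedged sketch (paths and converter are placeholders):

import edsnlp

stream = edsnlp.data.read_parquet("input/notes/", converter="omop")
stream = stream.set_processing(backend="spark")

# Each Spark partition batches and writes its own documents; special batch
# sizes like "fragment" rely on sentinels, which this backend does not emit.
edsnlp.data.write_parquet(
    stream,
    "output/notes/",
    batch_size=1024,
    overwrite=True,
)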
