Skip to content

Commit e76c22d

Browse files
committed
fix: propagate torch sharing strategy to other workers
1 parent d501eb6 commit e76c22d

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- Sort files before iterating over a standoff or json folder to ensure reproducibility
2525
- Sentence detection now correctly match capitalized letters + apostrophe
2626
- We now ensure that the workers pool is properly closed whatever happens (exception, garbage collection, data ending) in the `multiprocessing` backend. This prevents some executions from hanging indefinitely at the end of the processing.
27+
- Propagate torch sharing strategy to other workers in the `multiprocessing` backend. This is useful when the system is running out of file descriptors and `ulimit -n` is not an option. Torch sharing strategy can also be set via an environment variable `TORCH_SHARING_STRATEGY` (default is `file_descriptor`, [consider using `file_system` if you encounter issues](https://pytorch.org/docs/stable/multiprocessing.html#file-system-file-system)).
2728

2829
### Data API changes
2930

edsnlp/processing/multiprocessing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,6 @@ def cpu_count(): # pragma: no cover
226226
):
227227
replace_pickler()
228228

229-
if os.environ.get("TORCH_SHARING_STRATEGY"): # pragma: no cover
230-
try:
231-
torch.multiprocessing.set_sharing_strategy(os.environ["TORCH_SHARING_STRATEGY"])
232-
except NameError:
233-
pass
234229

235230
try:
236231
import torch
@@ -312,6 +307,12 @@ def load(file, *args, map_location=None, **kwargs):
312307
dump = dill.dump
313308

314309

310+
if os.environ.get("TORCH_SHARING_STRATEGY"): # pragma: no cover
311+
try:
312+
torch.multiprocessing.set_sharing_strategy(os.environ["TORCH_SHARING_STRATEGY"])
313+
except NameError:
314+
pass
315+
315316
U = TypeVar("U")
316317
T = TypeVar("T")
317318

0 commit comments

Comments
 (0)