
Commit 9d1a640

feat: enable (and expose) torch autocast and inference_mode
1 parent b73336a commit 9d1a640

4 files changed: 51 additions & 10 deletions

changelog.md
Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,10 @@
 
 ### Added
 
+- `data.set_processing(...)` now exposes an `autocast` parameter to disable or tweak the automatic casting of
+  tensors during processing. Autocasting should result in a slight speedup, but may lead to numerical instability.
+- Use `torch.inference_mode` to disable view tracking and version counter bumps during inference.
+
 ### Changed
 
 ### Fixed
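
A minimal usage sketch of the new parameter (the pipeline, the `edsnlp.data.from_iterable` reader and the texts below are placeholders for illustration, not part of this commit):

```python
import torch
import edsnlp

nlp = edsnlp.blank("eds")  # assumed: a pipeline that contains deep-learning components
data = edsnlp.data.from_iterable(
    ["Le patient présente une fièvre.", "Pas d'antécédent notable."]
).map_pipeline(nlp)

# autocast=True (default): AMP with default settings for the forward pass
# autocast=False:          disable AMP, e.g. if results become numerically unstable
# autocast=<dtype>:        the dtype is forwarded to the torch.autocast context manager
data = data.set_processing(
    backend="simple",
    autocast=torch.bfloat16,
    show_progress=True,
)
docs = list(data)
```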

edsnlp/core/lazy_collection.py
Lines changed: 17 additions & 5 deletions

@@ -131,6 +131,10 @@ def gpu_worker_devices(self):
     def cpu_worker_devices(self):
         return self.config.get("cpu_worker_devices")
 
+    @property
+    def autocast(self):
+        return self.config.get("autocast")
+
     @property
     def backend(self):
         backend = self.config.get("backend")
@@ -156,8 +160,9 @@ def set_processing(
         num_gpu_workers: Optional[int] = INFER,
         disable_implicit_parallelism: bool = True,
         backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-        gpu_pipe_names: Optional[List[str]] = INFER,
+        autocast: Union[bool, Any] = True,
         show_progress: bool = False,
+        gpu_pipe_names: Optional[List[str]] = INFER,
         process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
         gpu_worker_devices: Optional[List[str]] = INFER,
         cpu_worker_devices: Optional[List[str]] = INFER,
@@ -203,10 +208,6 @@ def set_processing(
         disable_implicit_parallelism: bool
             Whether to disable OpenMP and Huggingface tokenizers implicit parallelism in
             multiprocessing mode. Defaults to True.
-        gpu_pipe_names: Optional[List[str]]
-            List of pipe names to accelerate on a GPUWorker, defaults to all pipes
-            that inherit from TorchComponent. Only used with "multiprocessing" backend.
-            Inferred from the pipeline if not set.
         backend: Optional[Literal["simple", "multiprocessing", "spark"]]
             The backend to use for parallel processing. If not set, the backend is
             automatically selected based on the input data and the number of workers.
@@ -217,9 +218,20 @@ def set_processing(
             `num_gpu_workers` is greater than 0.
             - "spark" is used when the input data is a Spark dataframe and the output
               writer is a Spark writer.
+        autocast: Union[bool, Any]
+            Whether to use
+            [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html)
+            for the forward pass of the deep-learning components. If True (by default),
+            AMP will be used with the default settings. If False, AMP will not be used.
+            If a dtype is provided, it will be passed to the `torch.autocast` context
+            manager.
         show_progress: Optional[bool]
            Whether to show progress bars (only applicable with "simple" and
            "multiprocessing" backends).
+        gpu_pipe_names: Optional[List[str]]
+            List of pipe names to accelerate on a GPUWorker, defaults to all pipes
+            that inherit from TorchComponent. Only used with "multiprocessing" backend.
+            Inferred from the pipeline if not set.
         process_start_method: Optional[Literal["fork", "spawn"]]
             Whether to use "fork" or "spawn" as the start method for the multiprocessing
             backend. The default is "fork" on Unix systems and "spawn" on Windows.
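
For reference, this is roughly what the dtype form of the parameter maps to at the PyTorch level; a standalone sketch, independent of edsnlp internals:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
x = torch.randn(4, 8)

# Passing autocast=torch.bfloat16 amounts to wrapping the forward pass
# in torch.autocast with that dtype:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)  # torch.bfloat16 inside the autocast region, float32 outside it
```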

edsnlp/processing/multiprocessing.py
Lines changed: 10 additions & 1 deletion

@@ -656,13 +656,22 @@ def run(self):
                 if name in self.gpu_pipe_names
             ]
 
+            autocast_ctx = (
+                torch.autocast(
+                    device_type=self.device,
+                    dtype=lc.autocast,
+                )
+                if lc.autocast is not None
+                else nullcontext()
+            )
+
             del lc
             logging.info(f"Starting {self} on {os.getpid()}")
 
             # Inform the main process that we are ready
             self.exchanger.put_results((None, 0, None, None))
 
-            with torch.no_grad():  # , torch.cuda.amp.autocast():
+            with torch.no_grad(), autocast_ctx, torch.inference_mode():
                 while True:
                     stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
                     if task is None:
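
The extra `torch.inference_mode()` in the `with` statement is what the changelog entry refers to; a small standalone comparison with plain `no_grad` (assuming a PyTorch version that provides `inference_mode`):

```python
import torch

x = torch.ones(3)

with torch.no_grad():
    a = x * 2
with torch.inference_mode():
    b = x * 2

print(a.requires_grad, torch.is_inference(a))  # False False
print(b.requires_grad, torch.is_inference(b))  # False True

# Inference-mode tensors skip view tracking and version-counter bumps,
# which shaves some overhead but makes them unusable in autograd afterwards.
```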

edsnlp/processing/simple.py
Lines changed: 20 additions & 4 deletions

@@ -25,9 +25,25 @@ def execute_simple_backend(
     batch on the current process in a sequential manner.
     """
     try:
-        no_grad = sys.modules["torch"].no_grad
+        torch = sys.modules["torch"]
+        no_grad_ctx = torch.no_grad()
+        autocast_ctx = (
+            torch.autocast(
+                device_type=next(
+                    p.device for pipe in lc.torch_components for p in lc.parameters
+                ),
+                dtype=lc.autocast,
+            )
+            if lc.autocast is not None
+            else nullcontext()
+        )
+        inference_mode_ctx = (
+            torch.inference_mode()
+            if hasattr(torch, "inference_mode")
+            else nullcontext()
+        )
     except (KeyError, AttributeError):
-        no_grad = nullcontext
+        no_grad_ctx = autocast_ctx = inference_mode_ctx = nullcontext()
     reader = lc.reader
     writer = lc.writer
     show_progress = lc.show_progress
@@ -48,7 +64,7 @@ def process():
 
         bar = tqdm(smoothing=0.1, mininterval=5.0)
 
-        with bar, lc.eval():
+        with bar, lc.eval(), autocast_ctx, inference_mode_ctx:
             for docs in batchify(
                 (
                     subtask
@@ -64,7 +80,7 @@ def process():
 
             for batch in batchify_fns[lc.batch_by](docs, lc.batch_size):
                 count = len(batch)
-                with no_grad(), lc.cache():
+                with no_grad_ctx, lc.cache():
                     batch = apply_basic_pipes(batch, batch_components)
 
                 if writer is not None:
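
The `device_type=next(p.device ...)` expression above infers the autocast device from the parameters of the torch components; the underlying idiom in isolation (plain PyTorch, nothing edsnlp-specific):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)  # stand-in for a pipeline's torch components

# A module's device is conventionally read off one of its parameters.
device = next(model.parameters()).device
print(device)       # e.g. cpu or cuda:0
print(device.type)  # "cpu" / "cuda": the string form torch.autocast expects as device_type
```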
