Skip to content

Commit af06396

Browse files
committed
fix: check autocast is supported on device
1 parent 874d36a commit af06396

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

edsnlp/processing/multiprocessing.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -762,11 +762,15 @@ def process_items(self, stage):
762762
autocast = self.stream.autocast
763763
autocast_ctx = nullcontext()
764764
device = self.devices[self.uid]
765-
if autocast:
766-
autocast_ctx = torch.autocast(
767-
device_type=getattr(device, "type", device).split(":")[0],
768-
dtype=autocast if autocast is not True else None,
769-
)
765+
device_type = getattr(device, "type", device).split(":")[0]
766+
try:
767+
if autocast:
768+
autocast_ctx = torch.autocast(
769+
device_type=device_type,
770+
dtype=autocast if autocast is not True else None,
771+
)
772+
except RuntimeError: # pragma: no cover
773+
pass
770774

771775
with torch.no_grad(), autocast_ctx, torch.inference_mode():
772776
for item in self.iter_tasks(stage):
@@ -1249,7 +1253,12 @@ def adjust_num_workers(stream: Stream):
12491253
num_gpu_workers = 0
12501254

12511255
max_cpu_workers = max(num_cpus - num_gpu_workers - 1, 0)
1252-
default_cpu_workers = max(min(max_cpu_workers, num_gpu_workers * 4), 1)
1256+
default_cpu_workers = max(
1257+
min(max_cpu_workers, num_gpu_workers * 4)
1258+
if num_gpu_workers > 0
1259+
else max_cpu_workers,
1260+
1,
1261+
)
12531262
num_cpu_workers = (
12541263
default_cpu_workers
12551264
if stream.num_cpu_workers is None

edsnlp/processing/simple.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,22 @@ def execute_simple_backend(stream: Stream):
2020
try:
2121
torch = sys.modules["torch"]
2222
no_grad_ctx = torch.no_grad()
23-
autocast_device_type = next(
23+
device = next(
2424
(p.device for pipe in stream.torch_components() for p in pipe.parameters()),
2525
torch.device("cpu"),
26-
).type.split(":")[0]
27-
autocast_dtype = stream.autocast if stream.autocast is not True else None
28-
autocast_ctx = (
29-
torch.autocast(
30-
device_type=autocast_device_type,
31-
dtype=autocast_dtype,
32-
)
33-
if stream.autocast
34-
else nullcontext()
3526
)
27+
device_type = getattr(device, "type", device).split(":")[0]
28+
autocast = stream.autocast
29+
autocast_ctx = nullcontext()
30+
try:
31+
if autocast:
32+
autocast_ctx = torch.autocast(
33+
device_type=device_type,
34+
dtype=autocast if autocast is not True else None,
35+
)
36+
except RuntimeError: # pragma: no cover
37+
pass
38+
3639
inference_mode_ctx = (
3740
torch.inference_mode()
3841
if hasattr(torch, "inference_mode")

0 commit comments

Comments (0)