@@ -762,11 +762,15 @@ def process_items(self, stage):
762762 autocast = self .stream .autocast
763763 autocast_ctx = nullcontext ()
764764 device = self .devices [self .uid ]
765- if autocast :
766- autocast_ctx = torch .autocast (
767- device_type = getattr (device , "type" , device ).split (":" )[0 ],
768- dtype = autocast if autocast is not True else None ,
769- )
765+ device_type = getattr (device , "type" , device ).split (":" )[0 ]
766+ try :
767+ if autocast :
768+ autocast_ctx = torch .autocast (
769+ device_type = device_type ,
770+ dtype = autocast if autocast is not True else None ,
771+ )
772+ except RuntimeError : # pragma: no cover
773+ pass
770774
771775 with torch .no_grad (), autocast_ctx , torch .inference_mode ():
772776 for item in self .iter_tasks (stage ):
@@ -1249,7 +1253,12 @@ def adjust_num_workers(stream: Stream):
12491253 num_gpu_workers = 0
12501254
12511255 max_cpu_workers = max (num_cpus - num_gpu_workers - 1 , 0 )
1252- default_cpu_workers = max (min (max_cpu_workers , num_gpu_workers * 4 ), 1 )
1256+ default_cpu_workers = max (
1257+ min (max_cpu_workers , num_gpu_workers * 4 )
1258+ if num_gpu_workers > 0
1259+ else max_cpu_workers ,
1260+ 1 ,
1261+ )
12531262 num_cpu_workers = (
12541263 default_cpu_workers
12551264 if stream .num_cpu_workers is None
0 commit comments