@@ -131,6 +131,10 @@ def gpu_worker_devices(self):
131131 def cpu_worker_devices(self):
132132 return self.config.get("cpu_worker_devices")
133133
134+ @property
135+ def autocast(self):
136+ return self.config.get("autocast")
137+
134138 @property
135139 def backend(self):
136140 backend = self.config.get("backend")
@@ -156,8 +160,9 @@ def set_processing(
156160 num_gpu_workers: Optional[int] = INFER,
157161 disable_implicit_parallelism: bool = True,
158162 backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
159- gpu_pipe_names: Optional[List[str]] = INFER,
163+ autocast: Union[bool, Any] = True,
160164 show_progress: bool = False,
165+ gpu_pipe_names: Optional[List[str]] = INFER,
161166 process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
162167 gpu_worker_devices: Optional[List[str]] = INFER,
163168 cpu_worker_devices: Optional[List[str]] = INFER,
@@ -203,10 +208,6 @@ def set_processing(
203208 disable_implicit_parallelism: bool
204209 Whether to disable OpenMP and Huggingface tokenizers implicit parallelism in
205210 multiprocessing mode. Defaults to True.
206- gpu_pipe_names: Optional[List[str]]
207- List of pipe names to accelerate on a GPUWorker, defaults to all pipes
208- that inherit from TorchComponent. Only used with "multiprocessing" backend.
209- Inferred from the pipeline if not set.
210211 backend: Optional[Literal["simple", "multiprocessing", "spark"]]
211212 The backend to use for parallel processing. If not set, the backend is
212213 automatically selected based on the input data and the number of workers.
@@ -217,9 +218,20 @@ def set_processing(
217218 `num_gpu_workers` is greater than 0.
218219 - "spark" is used when the input data is a Spark dataframe and the output
219220 writer is a Spark writer.
221+ autocast: Union[bool, Any]
222+ Whether to use
223+ [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html)
224+ for the forward pass of the deep-learning components. If True (by default),
225+ AMP will be used with the default settings. If False, AMP will not be used.
226+ If a dtype is provided, it will be passed to the `torch.autocast` context
227+ manager.
220228 show_progress: Optional[bool]
221229 Whether to show progress bars (only applicable with "simple" and
222230 "multiprocessing" backends).
231+ gpu_pipe_names: Optional[List[str]]
232+ List of pipe names to accelerate on a GPUWorker, defaults to all pipes
233+ that inherit from TorchComponent. Only used with "multiprocessing" backend.
234+ Inferred from the pipeline if not set.
223235 process_start_method: Optional[Literal["fork", "spawn"]]
224236 Whether to use "fork" or "spawn" as the start method for the multiprocessing
225237 backend. The default is "fork" on Unix systems and "spawn" on Windows.
0 commit comments