@@ -94,7 +94,7 @@ def alpha_bar_fn(t):
9494
9595
9696# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97- def rescale_zero_terminal_snr (betas ) :
97+ def rescale_zero_terminal_snr (betas : torch . Tensor ) -> torch . Tensor :
9898 """
9999 Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
100100
@@ -144,16 +144,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
144144 The starting `beta` value of inference.
145145 beta_end (`float`, defaults to 0.02):
146146 The final `beta` value.
147- beta_schedule (`str `, defaults to `"linear"`):
147+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2" `, defaults to `"linear"`):
148148 The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149- `linear` or `scaled_linear `.
149+ `linear`, `scaled_linear`, or `squaredcos_cap_v2 `.
150150 trained_betas (`np.ndarray`, *optional*):
151151 Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152- prediction_type (`str `, defaults to `epsilon`, *optional*):
152+ prediction_type (`"epsilon" `, `"sample"`, or `"v_prediction"`, defaults to `" epsilon" `, *optional*):
153153 Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
154154 `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
155155 Video](https://huggingface.co/papers/2210.02303) paper).
156- timestep_spacing (`str `, defaults to `"linspace"`):
156+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing" `, defaults to `"linspace"`):
157157 The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
158158 Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
159159 steps_offset (`int`, defaults to 0):
@@ -173,13 +173,13 @@ def __init__(
173173 num_train_timesteps : int = 1000 ,
174174 beta_start : float = 0.0001 ,
175175 beta_end : float = 0.02 ,
176- beta_schedule : str = "linear" ,
176+ beta_schedule : Literal [ "linear" , "scaled_linear" , "squaredcos_cap_v2" ] = "linear" ,
177177 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
178- prediction_type : str = "epsilon" ,
179- timestep_spacing : str = "linspace" ,
178+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "epsilon" ,
179+ timestep_spacing : Literal [ "linspace" , "leading" , "trailing" ] = "linspace" ,
180180 steps_offset : int = 0 ,
181181 rescale_betas_zero_snr : bool = False ,
182- ):
182+ ) -> None :
183183 if trained_betas is not None :
184184 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
185185 elif beta_schedule == "linear" :
@@ -219,29 +219,29 @@ def __init__(
219219 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
220220
221221 @property
222- def init_noise_sigma (self ):
222+ def init_noise_sigma (self ) -> torch . Tensor :
223223 # standard deviation of the initial noise distribution
224224 if self .config .timestep_spacing in ["linspace" , "trailing" ]:
225225 return self .sigmas .max ()
226226
227227 return (self .sigmas .max () ** 2 + 1 ) ** 0.5
228228
229229 @property
230- def step_index (self ):
230+ def step_index (self ) -> Optional [ int ] :
231231 """
232232 The index counter for current timestep. It will increase 1 after each scheduler step.
233233 """
234234 return self ._step_index
235235
236236 @property
237- def begin_index (self ):
237+ def begin_index (self ) -> Optional [ int ] :
238238 """
239239 The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
240240 """
241241 return self ._begin_index
242242
243243 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
244- def set_begin_index (self , begin_index : int = 0 ):
244+ def set_begin_index (self , begin_index : int = 0 ) -> None :
245245 """
246246 Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
247247
@@ -259,7 +259,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
259259 Args:
260260 sample (`torch.Tensor`):
261261 The input sample.
262- timestep (`int`, *optional* ):
262+ timestep (`float` or `torch.Tensor` ):
263263 The current timestep in the diffusion chain.
264264
265265 Returns:
@@ -275,7 +275,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
275275 self .is_scale_input_called = True
276276 return sample
277277
278- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
278+ def set_timesteps (self , num_inference_steps : int , device : Optional [ Union [str , torch .device ]] = None ) -> None :
279279 """
280280 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
281281
@@ -381,13 +381,13 @@ def step(
381381 Args:
382382 model_output (`torch.Tensor`):
383383 The direct output from learned diffusion model.
384- timestep (`float`):
384+ timestep (`float` or `torch.Tensor` ):
385385 The current discrete timestep in the diffusion chain.
386386 sample (`torch.Tensor`):
387387 A current instance of a sample created by the diffusion process.
388388 generator (`torch.Generator`, *optional*):
389389 A random number generator.
390- return_dict (`bool`):
390+ return_dict (`bool`, defaults to `True` ):
391391 Whether or not to return a
392392 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
393393
@@ -517,5 +517,5 @@ def add_noise(
517517 noisy_samples = original_samples + noise * sigma
518518 return noisy_samples
519519
520- def __len__ (self ):
520+ def __len__ (self ) -> int :
521521 return self .config .num_train_timesteps
0 commit comments