@@ -77,7 +77,7 @@ def alpha_bar_fn(t):
7777
7878
7979# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
80- def rescale_zero_terminal_snr (betas ) :
80+ def rescale_zero_terminal_snr (betas : torch . Tensor ) -> torch . Tensor :
8181 """
8282 Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
8383
@@ -127,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
127127 The starting `beta` value of inference.
128128 beta_end (`float`, defaults to 0.02):
129129 The final `beta` value.
130- beta_schedule (`str `, defaults to `"linear"`):
130+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
131131 The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
132132 `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
133133 trained_betas (`np.ndarray`, *optional*):
134134 Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
135- solver_order (`int`, default `2`):
135+ solver_order (`int`, defaults to `2`):
136136 The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
137137 due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
138138 unconditional sampling.
139- prediction_type (`str `, defaults to `epsilon`, *optional*):
139+ prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
140140 Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
141- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
142- Video](https://huggingface.co/papers/2210.02303) paper).
141+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
142+ Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
143143 thresholding (`bool`, defaults to `False`):
144144 Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
145145 as Stable Diffusion.
@@ -149,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
149149 The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
150150 predict_x0 (`bool`, defaults to `True`):
151151 Whether to use the updating algorithm on the predicted x0.
152- solver_type (`str`, default ` bh2`):
152+ solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
153153 Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
154154 otherwise.
155155 lower_order_final (`bool`, defaults to `True`):
@@ -171,12 +171,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
171171 Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
172172 use_flow_sigmas (`bool`, *optional*, defaults to `False`):
173173 Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
174- timestep_spacing (`str `, defaults to `"linspace"`):
174+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
175175 The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
176176 Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
177177 steps_offset (`int`, defaults to 0):
178178 An offset added to the inference steps, as required by some model families.
179- final_sigmas_type (`str `, defaults to `"zero"`):
179+ final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
180180 The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
181181 sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
182182 rescale_betas_zero_snr (`bool`, defaults to `False`):
@@ -194,30 +194,30 @@ def __init__(
194194 num_train_timesteps : int = 1000 ,
195195 beta_start : float = 0.0001 ,
196196 beta_end : float = 0.02 ,
197- beta_schedule : str = "linear" ,
197+ beta_schedule : Literal [ "linear" , "scaled_linear" , "squaredcos_cap_v2" ] = "linear" ,
198198 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
199199 solver_order : int = 2 ,
200- prediction_type : str = "epsilon" ,
200+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" , "flow_prediction" ] = "epsilon" ,
201201 thresholding : bool = False ,
202202 dynamic_thresholding_ratio : float = 0.995 ,
203203 sample_max_value : float = 1.0 ,
204204 predict_x0 : bool = True ,
205- solver_type : str = "bh2" ,
205+ solver_type : Literal [ "bh1" , "bh2" ] = "bh2" ,
206206 lower_order_final : bool = True ,
207207 disable_corrector : List [int ] = [],
208- solver_p : SchedulerMixin = None ,
208+ solver_p : Optional [ SchedulerMixin ] = None ,
209209 use_karras_sigmas : Optional [bool ] = False ,
210210 use_exponential_sigmas : Optional [bool ] = False ,
211211 use_beta_sigmas : Optional [bool ] = False ,
212212 use_flow_sigmas : Optional [bool ] = False ,
213213 flow_shift : Optional [float ] = 1.0 ,
214- timestep_spacing : str = "linspace" ,
214+ timestep_spacing : Literal [ "linspace" , "leading" , "trailing" ] = "linspace" ,
215215 steps_offset : int = 0 ,
216- final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
216+ final_sigmas_type : Optional [Literal [ "zero" , "sigma_min" ]] = "zero" ,
217217 rescale_betas_zero_snr : bool = False ,
218218 use_dynamic_shifting : bool = False ,
219- time_shift_type : str = "exponential" ,
220- ):
219+ time_shift_type : Literal [ "exponential" ] = "exponential" ,
220+ ) -> None :
221221 if self .config .use_beta_sigmas and not is_scipy_available ():
222222 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
223223 if sum ([self .config .use_beta_sigmas , self .config .use_exponential_sigmas , self .config .use_karras_sigmas ]) > 1 :
@@ -279,21 +279,21 @@ def __init__(
279279 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
280280
281281 @property
282- def step_index (self ):
282+ def step_index (self ) -> Optional [ int ] :
283283 """
284284 The index counter for the current timestep. It increases by 1 after each scheduler step.
285285 """
286286 return self ._step_index
287287
288288 @property
289- def begin_index (self ):
289+ def begin_index (self ) -> Optional [ int ] :
290290 """
291291 The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
292292 """
293293 return self ._begin_index
294294
295295 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
296- def set_begin_index (self , begin_index : int = 0 ):
296+ def set_begin_index (self , begin_index : int = 0 ) -> None :
297297 """
298298 Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
299299
@@ -304,8 +304,8 @@ def set_begin_index(self, begin_index: int = 0):
304304 self ._begin_index = begin_index
305305
306306 def set_timesteps (
307- self , num_inference_steps : int , device : Union [str , torch .device ] = None , mu : Optional [float ] = None
308- ):
307+ self , num_inference_steps : int , device : Optional [ Union [str , torch .device ] ] = None , mu : Optional [float ] = None
308+ ) -> None :
309309 """
310310 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
311311
@@ -314,6 +314,8 @@ def set_timesteps(
314314 The number of diffusion steps used when generating samples with a pre-trained model.
315315 device (`str` or `torch.device`, *optional*):
316316 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
317+ mu (`float`, *optional*):
318+ Optional mu parameter for dynamic shifting when using exponential time shift type.
317319 """
318320 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
319321 if mu is not None :
@@ -475,7 +477,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
475477 return sample
476478
477479 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
478- def _sigma_to_t (self , sigma , log_sigmas ) :
480+ def _sigma_to_t (self , sigma : np . ndarray , log_sigmas : np . ndarray ) -> np . ndarray :
479481 """
480482 Convert sigma values to corresponding timestep values through interpolation.
481483
@@ -512,7 +514,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
512514 return t
513515
514516 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
515- def _sigma_to_alpha_sigma_t (self , sigma ) :
517+ def _sigma_to_alpha_sigma_t (self , sigma : torch . Tensor ) -> Tuple [ torch . Tensor , torch . Tensor ] :
516518 """
517519 Convert sigma values to alpha_t and sigma_t values.
518520
@@ -534,7 +536,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
534536 return alpha_t , sigma_t
535537
536538 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
537- def _convert_to_karras (self , in_sigmas : torch .Tensor , num_inference_steps ) -> torch .Tensor :
539+ def _convert_to_karras (self , in_sigmas : torch .Tensor , num_inference_steps : int ) -> torch .Tensor :
538540 """
539541 Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
540542 Models](https://huggingface.co/papers/2206.00364).
@@ -1030,7 +1032,7 @@ def index_for_timestep(
10301032 return step_index
10311033
10321034 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
1033- def _init_step_index (self , timestep ) :
1035+ def _init_step_index (self , timestep : Union [ int , torch . Tensor ]) -> None :
10341036 """
10351037 Initialize the step_index counter for the scheduler.
10361038
@@ -1060,11 +1062,11 @@ def step(
10601062 Args:
10611063 model_output (`torch.Tensor`):
10621064 The direct output from the learned diffusion model.
1063- timestep (`int`):
1065+ timestep (`int` or `torch.Tensor`):
10641066 The current discrete timestep in the diffusion chain.
10651067 sample (`torch.Tensor`):
10661068 A current instance of a sample created by the diffusion process.
1067- return_dict (`bool`):
1069+ return_dict (`bool`, defaults to `True`):
10681070 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
10691071
10701072 Returns:
@@ -1192,5 +1194,5 @@ def add_noise(
11921194 noisy_samples = alpha_t * original_samples + sigma_t * noise
11931195 return noisy_samples
11941196
1195- def __len__ (self ):
1197+ def __len__ (self ) -> int :
11961198 return self .config .num_train_timesteps
0 commit comments