Skip to content

Commit 2842c14

Browse files
authored
Improve docstrings and type hints in scheduling_unipc_multistep.py (#12767)
refactor: add type hints and update docstrings for UniPCMultistepScheduler parameters and methods.
1 parent c318686 commit 2842c14

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def alpha_bar_fn(t):
7777

7878

7979
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
80-
def rescale_zero_terminal_snr(betas):
80+
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
8181
"""
8282
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
8383
@@ -127,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
127127
The starting `beta` value of inference.
128128
beta_end (`float`, defaults to 0.02):
129129
The final `beta` value.
130-
beta_schedule (`str`, defaults to `"linear"`):
130+
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
131131
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
132132
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
133133
trained_betas (`np.ndarray`, *optional*):
134134
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
135-
solver_order (`int`, default `2`):
135+
solver_order (`int`, defaults to `2`):
136136
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
137137
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
138138
unconditional sampling.
139-
prediction_type (`str`, defaults to `epsilon`, *optional*):
139+
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
140140
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
141-
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
142-
Video](https://huggingface.co/papers/2210.02303) paper).
141+
`sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
142+
Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
143143
thresholding (`bool`, defaults to `False`):
144144
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
145145
as Stable Diffusion.
@@ -149,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
149149
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
150150
predict_x0 (`bool`, defaults to `True`):
151151
Whether to use the updating algorithm on the predicted x0.
152-
solver_type (`str`, default `bh2`):
152+
solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
153153
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
154154
otherwise.
155155
lower_order_final (`bool`, default `True`):
@@ -171,12 +171,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
171171
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
172172
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
173173
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
174-
timestep_spacing (`str`, defaults to `"linspace"`):
174+
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
175175
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
176176
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
177177
steps_offset (`int`, defaults to 0):
178178
An offset added to the inference steps, as required by some model families.
179-
final_sigmas_type (`str`, defaults to `"zero"`):
179+
final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
180180
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
181181
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
182182
rescale_betas_zero_snr (`bool`, defaults to `False`):
@@ -194,30 +194,30 @@ def __init__(
194194
num_train_timesteps: int = 1000,
195195
beta_start: float = 0.0001,
196196
beta_end: float = 0.02,
197-
beta_schedule: str = "linear",
197+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
198198
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
199199
solver_order: int = 2,
200-
prediction_type: str = "epsilon",
200+
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
201201
thresholding: bool = False,
202202
dynamic_thresholding_ratio: float = 0.995,
203203
sample_max_value: float = 1.0,
204204
predict_x0: bool = True,
205-
solver_type: str = "bh2",
205+
solver_type: Literal["bh1", "bh2"] = "bh2",
206206
lower_order_final: bool = True,
207207
disable_corrector: List[int] = [],
208-
solver_p: SchedulerMixin = None,
208+
solver_p: Optional[SchedulerMixin] = None,
209209
use_karras_sigmas: Optional[bool] = False,
210210
use_exponential_sigmas: Optional[bool] = False,
211211
use_beta_sigmas: Optional[bool] = False,
212212
use_flow_sigmas: Optional[bool] = False,
213213
flow_shift: Optional[float] = 1.0,
214-
timestep_spacing: str = "linspace",
214+
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
215215
steps_offset: int = 0,
216-
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
216+
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
217217
rescale_betas_zero_snr: bool = False,
218218
use_dynamic_shifting: bool = False,
219-
time_shift_type: str = "exponential",
220-
):
219+
time_shift_type: Literal["exponential"] = "exponential",
220+
) -> None:
221221
if self.config.use_beta_sigmas and not is_scipy_available():
222222
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
223223
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -279,21 +279,21 @@ def __init__(
279279
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
280280

281281
@property
282-
def step_index(self):
282+
def step_index(self) -> Optional[int]:
283283
"""
284284
The index counter for current timestep. It will increase 1 after each scheduler step.
285285
"""
286286
return self._step_index
287287

288288
@property
289-
def begin_index(self):
289+
def begin_index(self) -> Optional[int]:
290290
"""
291291
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
292292
"""
293293
return self._begin_index
294294

295295
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
296-
def set_begin_index(self, begin_index: int = 0):
296+
def set_begin_index(self, begin_index: int = 0) -> None:
297297
"""
298298
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
299299
@@ -304,8 +304,8 @@ def set_begin_index(self, begin_index: int = 0):
304304
self._begin_index = begin_index
305305

306306
def set_timesteps(
307-
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
308-
):
307+
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
308+
) -> None:
309309
"""
310310
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
311311
@@ -314,6 +314,8 @@ def set_timesteps(
314314
The number of diffusion steps used when generating samples with a pre-trained model.
315315
device (`str` or `torch.device`, *optional*):
316316
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
317+
mu (`float`, *optional*):
318+
Optional mu parameter for dynamic shifting when using exponential time shift type.
317319
"""
318320
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
319321
if mu is not None:
@@ -475,7 +477,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
475477
return sample
476478

477479
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
478-
def _sigma_to_t(self, sigma, log_sigmas):
480+
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
479481
"""
480482
Convert sigma values to corresponding timestep values through interpolation.
481483
@@ -512,7 +514,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
512514
return t
513515

514516
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
515-
def _sigma_to_alpha_sigma_t(self, sigma):
517+
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
516518
"""
517519
Convert sigma values to alpha_t and sigma_t values.
518520
@@ -534,7 +536,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
534536
return alpha_t, sigma_t
535537

536538
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
537-
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
539+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
538540
"""
539541
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
540542
Models](https://huggingface.co/papers/2206.00364).
@@ -1030,7 +1032,7 @@ def index_for_timestep(
10301032
return step_index
10311033

10321034
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
1033-
def _init_step_index(self, timestep):
1035+
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
10341036
"""
10351037
Initialize the step_index counter for the scheduler.
10361038
@@ -1060,11 +1062,11 @@ def step(
10601062
Args:
10611063
model_output (`torch.Tensor`):
10621064
The direct output from learned diffusion model.
1063-
timestep (`int`):
1065+
timestep (`int` or `torch.Tensor`):
10641066
The current discrete timestep in the diffusion chain.
10651067
sample (`torch.Tensor`):
10661068
A current instance of a sample created by the diffusion process.
1067-
return_dict (`bool`):
1069+
return_dict (`bool`, defaults to `True`):
10681070
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
10691071
10701072
Returns:
@@ -1192,5 +1194,5 @@ def add_noise(
11921194
noisy_samples = alpha_t * original_samples + sigma_t * noise
11931195
return noisy_samples
11941196

1195-
def __len__(self):
1197+
def __len__(self) -> int:
11961198
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)