Skip to content

Commit d769d8a

Browse files
authored
Improve docstrings and type hints in scheduling_heun_discrete.py (#12726)
refactor: improve type hints for `beta_schedule`, `prediction_type`, and `timestep_spacing` parameters, and add return type hints to several methods.
1 parent c25582d commit d769d8a

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
107107
The starting `beta` value of inference.
108108
beta_end (`float`, defaults to 0.02):
109109
The final `beta` value.
110-
beta_schedule (`str`, defaults to `"linear"`):
110+
beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"exp"`, defaults to `"linear"`):
111111
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
112-
`linear` or `scaled_linear`.
112+
`linear`, `scaled_linear`, `squaredcos_cap_v2`, or `exp`.
113113
trained_betas (`np.ndarray`, *optional*):
114114
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
115-
prediction_type (`str`, defaults to `epsilon`, *optional*):
115+
prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
116116
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
117117
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
118118
Video](https://huggingface.co/papers/2210.02303) paper).
@@ -128,7 +128,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
128128
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
129129
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
130130
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
131-
timestep_spacing (`str`, defaults to `"linspace"`):
131+
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
132132
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
133133
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
134134
steps_offset (`int`, defaults to 0):
@@ -144,17 +144,17 @@ def __init__(
144144
num_train_timesteps: int = 1000,
145145
beta_start: float = 0.00085, # sensible defaults
146146
beta_end: float = 0.012,
147-
beta_schedule: str = "linear",
147+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "exp"] = "linear",
148148
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
149-
prediction_type: str = "epsilon",
149+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
150150
use_karras_sigmas: Optional[bool] = False,
151151
use_exponential_sigmas: Optional[bool] = False,
152152
use_beta_sigmas: Optional[bool] = False,
153153
clip_sample: Optional[bool] = False,
154154
clip_sample_range: float = 1.0,
155-
timestep_spacing: str = "linspace",
155+
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
156156
steps_offset: int = 0,
157-
):
157+
) -> None:
158158
if self.config.use_beta_sigmas and not is_scipy_available():
159159
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
160160
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -241,7 +241,7 @@ def begin_index(self):
241241
return self._begin_index
242242

243243
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
244-
def set_begin_index(self, begin_index: int = 0):
244+
def set_begin_index(self, begin_index: int = 0) -> None:
245245
"""
246246
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
247247
@@ -263,7 +263,7 @@ def scale_model_input(
263263
Args:
264264
sample (`torch.Tensor`):
265265
The input sample.
266-
timestep (`int`, *optional*):
266+
timestep (`float` or `torch.Tensor`):
267267
The current timestep in the diffusion chain.
268268
269269
Returns:
@@ -283,19 +283,19 @@ def set_timesteps(
283283
device: Union[str, torch.device] = None,
284284
num_train_timesteps: Optional[int] = None,
285285
timesteps: Optional[List[int]] = None,
286-
):
286+
) -> None:
287287
"""
288288
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
289289
290290
Args:
291-
num_inference_steps (`int`):
291+
num_inference_steps (`int`, *optional*, defaults to `None`):
292292
The number of diffusion steps used when generating samples with a pre-trained model.
293-
device (`str` or `torch.device`, *optional*):
293+
device (`str`, `torch.device`, *optional*, defaults to `None`):
294294
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
295-
num_train_timesteps (`int`, *optional*):
295+
num_train_timesteps (`int`, *optional*, defaults to `None`):
296296
The number of diffusion steps used when training the model. If `None`, the default
297297
`num_train_timesteps` attribute is used.
298-
timesteps (`List[int]`, *optional*):
298+
timesteps (`List[int]`, *optional*, defaults to `None`):
299299
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
300300
generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
301301
must be `None`, and `timestep_spacing` attribute will be ignored.
@@ -370,7 +370,7 @@ def set_timesteps(
370370
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
371371

372372
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
373-
def _sigma_to_t(self, sigma, log_sigmas):
373+
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
374374
"""
375375
Convert sigma values to corresponding timestep values through interpolation.
376376
@@ -407,7 +407,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
407407
return t
408408

409409
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
410-
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
410+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
411411
"""
412412
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
413413
Models](https://huggingface.co/papers/2206.00364).
@@ -700,5 +700,5 @@ def add_noise(
700700
noisy_samples = original_samples + noise * sigma
701701
return noisy_samples
702702

703-
def __len__(self):
703+
def __len__(self) -> int:
704704
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)