Skip to content

Commit 859b809

Browse files
authored
Improve docstrings and type hints in scheduling_euler_ancestral_discrete.py (#12766)
refactor: add type hints to methods and update docstrings for parameters.
1 parent d769d8a commit 859b809

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def alpha_bar_fn(t):
9494

9595

9696
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97-
def rescale_zero_terminal_snr(betas):
97+
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
9898
"""
9999
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
100100
@@ -144,16 +144,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
144144
The starting `beta` value of inference.
145145
beta_end (`float`, defaults to 0.02):
146146
The final `beta` value.
147-
beta_schedule (`str`, defaults to `"linear"`):
147+
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
148148
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149-
`linear` or `scaled_linear`.
149+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
150150
trained_betas (`np.ndarray`, *optional*):
151151
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152-
prediction_type (`str`, defaults to `epsilon`, *optional*):
152+
prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
153153
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
154154
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
155155
Video](https://huggingface.co/papers/2210.02303) paper).
156-
timestep_spacing (`str`, defaults to `"linspace"`):
156+
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
157157
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
158158
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
159159
steps_offset (`int`, defaults to 0):
@@ -173,13 +173,13 @@ def __init__(
173173
num_train_timesteps: int = 1000,
174174
beta_start: float = 0.0001,
175175
beta_end: float = 0.02,
176-
beta_schedule: str = "linear",
176+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
177177
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
178-
prediction_type: str = "epsilon",
179-
timestep_spacing: str = "linspace",
178+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
179+
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
180180
steps_offset: int = 0,
181181
rescale_betas_zero_snr: bool = False,
182-
):
182+
) -> None:
183183
if trained_betas is not None:
184184
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
185185
elif beta_schedule == "linear":
@@ -219,29 +219,29 @@ def __init__(
219219
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
220220

221221
@property
222-
def init_noise_sigma(self):
222+
def init_noise_sigma(self) -> torch.Tensor:
223223
# standard deviation of the initial noise distribution
224224
if self.config.timestep_spacing in ["linspace", "trailing"]:
225225
return self.sigmas.max()
226226

227227
return (self.sigmas.max() ** 2 + 1) ** 0.5
228228

229229
@property
230-
def step_index(self):
230+
def step_index(self) -> Optional[int]:
231231
"""
232232
The index counter for current timestep. It will increase 1 after each scheduler step.
233233
"""
234234
return self._step_index
235235

236236
@property
237-
def begin_index(self):
237+
def begin_index(self) -> Optional[int]:
238238
"""
239239
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
240240
"""
241241
return self._begin_index
242242

243243
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
244-
def set_begin_index(self, begin_index: int = 0):
244+
def set_begin_index(self, begin_index: int = 0) -> None:
245245
"""
246246
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
247247
@@ -259,7 +259,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
259259
Args:
260260
sample (`torch.Tensor`):
261261
The input sample.
262-
timestep (`int`, *optional*):
262+
timestep (`float` or `torch.Tensor`):
263263
The current timestep in the diffusion chain.
264264
265265
Returns:
@@ -275,7 +275,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
275275
self.is_scale_input_called = True
276276
return sample
277277

278-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
278+
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
279279
"""
280280
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
281281
@@ -381,13 +381,13 @@ def step(
381381
Args:
382382
model_output (`torch.Tensor`):
383383
The direct output from learned diffusion model.
384-
timestep (`float`):
384+
timestep (`float` or `torch.Tensor`):
385385
The current discrete timestep in the diffusion chain.
386386
sample (`torch.Tensor`):
387387
A current instance of a sample created by the diffusion process.
388388
generator (`torch.Generator`, *optional*):
389389
A random number generator.
390-
return_dict (`bool`):
390+
return_dict (`bool`, defaults to `True`):
391391
Whether or not to return a
392392
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
393393
@@ -517,5 +517,5 @@ def add_noise(
517517
noisy_samples = original_samples + noise * sigma
518518
return noisy_samples
519519

520-
def __len__(self):
520+
def __len__(self) -> int:
521521
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)