From 239537f448320baa890a829277052799e12512d0 Mon Sep 17 00:00:00 2001
From: David El Malih
Date: Fri, 5 Dec 2025 19:45:29 +0100
Subject: [PATCH] feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler

---
 .../scheduling_dpmsolver_singlestep.py | 124 +++++++++++-------
 1 file changed, 74 insertions(+), 50 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 55c9fb6e7384..4916e1abb549 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -86,42 +86,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, *optional*):
+        trained_betas (`np.ndarray` or `List[float]`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        solver_order (`int`, defaults to 2):
+        solver_order (`int`, defaults to `2`):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
+            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-        sample_max_value (`float`, defaults to 1.0):
+        sample_max_value (`float`, defaults to `1.0`):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
             type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
             `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
             paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
             sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
-        lower_order_final (`bool`, defaults to `True`):
+        lower_order_final (`bool`, defaults to `False`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -132,15 +132,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to `1.0`):
+            The shift value for the timestep schedule for flow matching; only used when `use_flow_sigmas=True`.
+        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to shift the timestep schedule dynamically based on the `mu` passed to `set_timesteps`.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of dynamic time shifting to apply; only `"exponential"` is currently supported.
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -152,27 +160,27 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", + algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, + variance_type: Optional[Literal["learned", "learned_range"]] = None, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", - ): + time_shift_type: Literal["exponential"] = "exponential", + ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: @@ -242,6 +250,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + + Returns: + `List[int]`: + The list of solver orders for each timestep. """ steps = num_inference_steps order = self.config.solver_order @@ -276,21 +288,29 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: return orders @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 
@@ -302,19 +322,21 @@ def set_begin_index(self, begin_index: int = 0):
 
     def set_timesteps(
         self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
         mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The `mu` parameter for dynamically shifting the timestep schedule when `use_dynamic_shifting=True`.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
@@ -453,7 +475,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
@@ -490,7 +512,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
 
@@ -512,7 +534,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
@@ -637,7 +659,7 @@ def convert_model_output(
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -733,7 +755,7 @@ def dpm_solver_first_order_update(
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -797,7 +819,7 @@ def singlestep_dpm_solver_second_order_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -908,7 +930,7 @@ def singlestep_dpm_solver_third_order_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1030,8 +1052,8 @@ def singlestep_dpm_solver_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
-        order: int = None,
+        sample: Optional[torch.Tensor] = None,
+        order: Optional[int] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1125,7 +1147,7 @@ def index_for_timestep(
         return step_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.
 
@@ -1146,7 +1168,7 @@ def step(
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -1156,11 +1178,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`int`):
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            generator (`torch.Generator`, *optional*):
+                A random number generator for stochastic sampling.
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
@@ -1277,5 +1301,5 @@ def add_noise(
         noisy_samples = alpha_t * original_samples + sigma_t * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
        return self.config.num_train_timesteps
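
A minimal usage sketch of the options documented above, assuming this scheduler mirrors the multistep
variant's flow-matching behavior; `flow_shift=3.0`, `mu=0.8`, and the latent shape are illustrative values,
and the random tensor stands in for a real flow-prediction model:

    import torch
    from diffusers import DPMSolverSinglestepScheduler

    # Static flow shift: flow sigmas with `flow_shift` baked into the schedule.
    scheduler = DPMSolverSinglestepScheduler(
        use_flow_sigmas=True,
        flow_shift=3.0,
        prediction_type="flow_prediction",
    )
    scheduler.set_timesteps(num_inference_steps=20)

    # Dynamic shifting: the shift is resolved per call from the `mu` passed to
    # `set_timesteps` instead of the static `flow_shift`.
    dyn_scheduler = DPMSolverSinglestepScheduler(
        use_flow_sigmas=True,
        use_dynamic_shifting=True,
        prediction_type="flow_prediction",
    )
    dyn_scheduler.set_timesteps(num_inference_steps=20, mu=0.8)

    # Skeleton denoising loop over the shifted schedule.
    sample = torch.randn(1, 4, 64, 64)
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # placeholder for model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample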