@@ -100,10 +100,19 @@ def retrieve_timesteps(
100100 `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
101101 second element is the number of inference steps.
102102 """
103+ accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
104+ accepts_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
105+
103106 if timesteps is not None and sigmas is not None :
104- raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
105- if timesteps is not None :
106- accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
107+ if not accepts_timesteps and not accepts_sigmas :
108+ raise ValueError (
109+ f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
110+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
111+ )
112+ scheduler .set_timesteps (timesteps = timesteps , sigmas = sigmas , device = device , ** kwargs )
113+ timesteps = scheduler .timesteps
114+ num_inference_steps = len (timesteps )
115+ elif timesteps is not None and sigmas is None :
107116 if not accepts_timesteps :
108117 raise ValueError (
109118 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
@@ -112,9 +121,8 @@ def retrieve_timesteps(
112121 scheduler .set_timesteps (timesteps = timesteps , device = device , ** kwargs )
113122 timesteps = scheduler .timesteps
114123 num_inference_steps = len (timesteps )
115- elif sigmas is not None :
116- accept_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
117- if not accept_sigmas :
124+ elif timesteps is None and sigmas is not None :
125+ if not accepts_sigmas :
118126 raise ValueError (
119127 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
120128 f" sigmas schedules. Please check whether you are using the correct scheduler."
0 commit comments