# limitations under the License.

from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch

_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"

-# Default Mag Ratios for Flux models (Dev/Schnell)
+# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
+# Users must explicitly pass these to the config when using Flux.
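+# e.g.: MagCacheConfig(num_inference_steps=28, mag_ratios=FLUX_MAG_RATIOS)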
# Reference: https://github.com/Zehong-Ma/MagCache
FLUX_MAG_RATIOS = np.array(
    [1.0]
@@ -97,38 +98,62 @@ class MagCacheConfig:
        num_inference_steps (`int`, defaults to `28`):
            The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios`
            correctly.
        mag_ratios (`np.ndarray`, *optional*):
-            The pre-computed magnitude ratios for the model. If not provided, defaults to the Flux ratios.
+            The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not
+            provided, you must set `calibrate=True` to calculate them for your specific model. For Flux
+            models, you can use `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
+        calibrate (`bool`, defaults to `False`):
+            If `True`, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook
+            calculates the magnitude ratios for the current run and logs them at the end. Use this to
+            obtain `mag_ratios` for new models or schedulers.
    """

    threshold: float = 0.24
    max_skip_steps: int = 5
    retention_ratio: float = 0.1
    num_inference_steps: int = 28
    mag_ratios: Optional[np.ndarray] = None
+    calibrate: bool = False

    def __post_init__(self):
-        if self.mag_ratios is None:
-            self.mag_ratios = FLUX_MAG_RATIOS
-
-        if len(self.mag_ratios) != self.num_inference_steps:
-            logger.debug(
-                f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
+        # Strict validation: the user must provide ratios or enable calibration.
+        if self.mag_ratios is None and not self.calibrate:
+            raise ValueError(
+                "`mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
+                "To get them for your model:\n"
+                "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
+                "2. Run inference on your model once.\n"
+                "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
+                "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
            )
-            self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
+
+        if not self.calibrate and self.mag_ratios is not None:
+            if len(self.mag_ratios) != self.num_inference_steps:
+                logger.debug(
+                    f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
+                )
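+                # Nearest-neighbor resampling of the ratio table, so that each of the
+                # `num_inference_steps` timesteps gets the closest calibrated ratio.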
+                self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)

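+# A minimal calibration workflow sketch (illustrative only: `pipe` and its `transformer`
+# attribute are assumed to come from a diffusers pipeline and are not defined here):
+#
+#     config = MagCacheConfig(num_inference_steps=28, calibrate=True)
+#     apply_mag_cache(pipe.transformer, config)
+#     pipe(prompt, num_inference_steps=28)  # ratios are printed at the end of the run
+#     # Then paste the printed values back in for fast inference:
+#     config = MagCacheConfig(num_inference_steps=28, mag_ratios=np.array([...]))
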
class MagCacheState(BaseState):
    def __init__(self) -> None:
        super().__init__()
+        # Cache for the residual (output - input) from the *previous* timestep
        self.previous_residual: torch.Tensor = None

+        # State inputs/outputs for the current forward pass
        self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

+        # MagCache accumulators
        self.accumulated_ratio: float = 1.0
        self.accumulated_err: float = 0.0
        self.accumulated_steps: int = 0
+
+        # Current step counter (timestep index)
        self.step_index: int = 0
+
+        # Calibration storage
+        self.calibration_ratios: List[float] = []

    def reset(self):
        self.previous_residual = None
@@ -137,6 +162,7 @@ def reset(self):
        self.accumulated_err = 0.0
        self.accumulated_steps = 0
        self.step_index = 0
+        self.calibration_ratios = []


class MagCacheHeadHook(ModelHook):
@@ -153,36 +179,42 @@ def initialize_hook(self, module):
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        # Capture the input hidden_states
        hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        state: MagCacheState = self.state_manager.get_state()
        state.head_block_input = hidden_states

        should_compute = True

-        current_step = state.step_index
-        if current_step >= len(self.config.mag_ratios):
-            current_scale = 1.0
+        if self.config.calibrate:
+            # Never skip during calibration
+            should_compute = True
        else:
-            current_scale = self.config.mag_ratios[current_step]
+            # MagCache logic
+            current_step = state.step_index
+            if current_step >= len(self.config.mag_ratios):
+                current_scale = 1.0
+            else:
+                current_scale = self.config.mag_ratios[current_step]

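+            # Warmup: the first `retention_ratio` fraction of steps is always computed;
+            # skipping is only considered after this window.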
-        retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
+            retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)

-        if current_step >= retention_step:
-            state.accumulated_ratio *= current_scale
-            state.accumulated_steps += 1
-            state.accumulated_err += abs(1.0 - state.accumulated_ratio)
+            if current_step >= retention_step:
+                state.accumulated_ratio *= current_scale
+                state.accumulated_steps += 1
+                state.accumulated_err += abs(1.0 - state.accumulated_ratio)

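+                # Skip recomputation only while a cached residual exists, the accumulated error
+                # stays below `threshold`, and at most `max_skip_steps` consecutive steps have
+                # been skipped; otherwise compute this step and reset the accumulators.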
-        if (
-            state.previous_residual is not None
-            and state.accumulated_err <= self.config.threshold
-            and state.accumulated_steps <= self.config.max_skip_steps
-        ):
-            should_compute = False
-        else:
-            state.accumulated_ratio = 1.0
-            state.accumulated_steps = 0
-            state.accumulated_err = 0.0
+                if (
+                    state.previous_residual is not None
+                    and state.accumulated_err <= self.config.threshold
+                    and state.accumulated_steps <= self.config.max_skip_steps
+                ):
+                    should_compute = False
+                else:
+                    state.accumulated_ratio = 1.0
+                    state.accumulated_steps = 0
+                    state.accumulated_err = 0.0

        state.should_compute = should_compute

@@ -193,6 +225,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
            output = hidden_states
            res = state.previous_residual

+            # Apply the cached residual, handling shape mismatches (e.g., text+image vs. image-only tokens)
            if res.shape == output.shape:
                output = output + res
            elif (
@@ -201,6 +234,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
                and output.shape[0] == res.shape[0]
                and output.shape[2] == res.shape[2]
            ):
+                # Assume concatenation with the image tokens at the end (standard in Flux/SD3)
                diff = output.shape[1] - res.shape[1]
                if diff > 0:
                    output = output.clone()
@@ -220,20 +254,18 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                    "encoder_hidden_states", args, kwargs
                )
-
                max_idx = max(
                    self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
                )
                ret_list = [None] * (max_idx + 1)
-
                ret_list[self._metadata.return_hidden_states_index] = output
                ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
-
                return tuple(ret_list)
            else:
                return output

        else:
+            # Not skipping: run the original forward pass
            output = self.fn_ref.original_forward(*args, **kwargs)
            return output

@@ -260,21 +292,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if not state.should_compute:
            hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
-
            if self.is_tail:
-                state.step_index += 1
-                if state.step_index >= self.config.num_inference_steps:
-                    state.step_index = 0
-                    state.accumulated_ratio = 1.0
-                    state.accumulated_steps = 0
-                    state.accumulated_err = 0.0
-                    state.previous_residual = None
+                # Still need to advance the step index even when we skip
+                self._advance_step(state)

            if self._metadata.return_encoder_hidden_states_index is not None:
                encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                    "encoder_hidden_states", args, kwargs
                )
-
                max_idx = max(
                    self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
                )
@@ -285,38 +310,71 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
            return hidden_states

-
        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_tail:
+            # Calculate the residual for use in later steps
            if isinstance(output, tuple):
                out_hidden = output[self._metadata.return_hidden_states_index]
            else:
                out_hidden = output

            in_hidden = state.head_block_input

+            # Determine the residual
            if out_hidden.shape == in_hidden.shape:
                residual = out_hidden - in_hidden
            elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
                diff = in_hidden.shape[1] - out_hidden.shape[1]
                if diff == 0:
                    residual = out_hidden - in_hidden
                else:
-                    residual = out_hidden - in_hidden
+                    # Sequence lengths differ (e.g., the input still carries text tokens);
+                    # fall back to comparing against the matching tail tokens.
+                    if diff > 0:
+                        residual = out_hidden - in_hidden[:, -out_hidden.shape[1] :]
+                    else:
+                        residual = out_hidden[:, -in_hidden.shape[1] :] - in_hidden
+            else:
+                # Fallback for completely mismatched shapes: not a meaningful residual,
+                # but it prevents a crash
+                residual = out_hidden
-            state.previous_residual = residual
+            if self.config.calibrate:
+                self._perform_calibration_step(state, residual)

-            state.step_index += 1
-            if state.step_index >= self.config.num_inference_steps:
-                state.step_index = 0
-                state.accumulated_ratio = 1.0
-                state.accumulated_steps = 0
-                state.accumulated_err = 0.0
-                state.previous_residual = None
+            state.previous_residual = residual
+            self._advance_step(state)

        return output

+    def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
+        if state.previous_residual is None:
+            # The first step has no previous residual to compare against; log 1.0 as a neutral start.
+            ratio = 1.0
+        else:
+            # MagCache calibration formula: mean(norm(curr) / norm(prev)),
+            # where the norm over dim=-1 gives the magnitude of each token vector.
+            curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
+            prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
+
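+            # For hidden states of shape (batch, seq_len, dim), these norms have shape
+            # (batch, seq_len); the mean below reduces them to a single scalar ratio.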
+            # Avoid division by zero
+            ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
+
+        state.calibration_ratios.append(ratio)
+
+    def _advance_step(self, state: MagCacheState):
+        state.step_index += 1
+        if state.step_index >= self.config.num_inference_steps:
+            # End of the inference loop
+            if self.config.calibrate:
+                print("\n[MagCache] Calibration complete. Copy these values to MagCacheConfig(mag_ratios=...):")
+                print(f"{state.calibration_ratios}\n")
+                logger.info(f"MagCache calibration results: {state.calibration_ratios}")
+
+            # Reset state for the next run
+            state.step_index = 0
+            state.accumulated_ratio = 1.0
+            state.accumulated_steps = 0
+            state.accumulated_err = 0.0
+            state.previous_residual = None
+            state.calibration_ratios = []

def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
    """
@@ -331,7 +389,6 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
    state_manager = StateManager(MagCacheState, (), {})
    remaining_blocks = []

-    # Identify blocks
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
@@ -342,6 +399,16 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
        logger.warning("MagCache: No transformer blocks found to apply hooks.")
        return

+    if len(remaining_blocks) == 1:
+        # Single-block case: the block acts as both head (skip decision) and tail (residual calculation)
+        name, block = remaining_blocks[0]
+        logger.info(f"MagCache: Applying Head+Tail hooks to single block '{name}'")
+        # Apply the BlockHook (tail) FIRST so it is the INNER wrapper,
+        # then the HeadHook SECOND so it is the OUTER wrapper that controls the flow
+        _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
+        _apply_mag_cache_head_hook(block, state_manager, config)
+        return
+
    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

@@ -371,4 +438,4 @@ def _apply_mag_cache_block_hook(
) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = MagCacheBlockHook(state_manager, is_tail, config)
-    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
+    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)