Commit a6a9fb4

add magcache support with calibration mode
1 parent cbf4b5e commit a6a9fb4

File tree

2 files changed: 279 additions, 88 deletions

src/diffusers/hooks/mag_cache.py

Lines changed: 121 additions & 54 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -30,7 +30,8 @@
 _MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
 _MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
 
-# Default Mag Ratios for Flux models (Dev/Schnell)
+# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
+# Users must explicitly pass these to the config if using Flux.
 # Reference: https://github.com/Zehong-Ma/MagCache
 FLUX_MAG_RATIOS = np.array(
     [1.0]
@@ -97,38 +98,62 @@ class MagCacheConfig:
         num_inference_steps (`int`, defaults to `28`):
             The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
         mag_ratios (`np.ndarray`, *optional*):
-            The pre-computed magnitude ratios for the model. If not provided, defaults to the Flux ratios.
+            The pre-computed magnitude ratios for the model. These are checkpoint-dependent.
+            If not provided, you must set `calibrate=True` to calculate them for your specific model.
+            For Flux models, you can use `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
+        calibrate (`bool`, defaults to `False`):
+            If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates
+            the magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios`
+            for new models or schedulers.
     """
 
     threshold: float = 0.24
     max_skip_steps: int = 5
     retention_ratio: float = 0.1
     num_inference_steps: int = 28
     mag_ratios: Optional[np.ndarray] = None
+    calibrate: bool = False
 
     def __post_init__(self):
-        if self.mag_ratios is None:
-            self.mag_ratios = FLUX_MAG_RATIOS
-
-        if len(self.mag_ratios) != self.num_inference_steps:
-            logger.debug(
-                f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
+        # Strict validation: the user MUST provide ratios OR enable calibration.
+        if self.mag_ratios is None and not self.calibrate:
+            raise ValueError(
+                "`mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
+                "To get them for your model:\n"
+                "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
+                "2. Run inference on your model once.\n"
+                "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
+                "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
             )
-            self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
+
+        if not self.calibrate and self.mag_ratios is not None:
+            if len(self.mag_ratios) != self.num_inference_steps:
+                logger.debug(
+                    f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
+                )
+                self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
 
 
 class MagCacheState(BaseState):
     def __init__(self) -> None:
         super().__init__()
+        # Cache for the residual (output - input) from the *previous* timestep
         self.previous_residual: torch.Tensor = None
 
+        # State inputs/outputs for the current forward pass
         self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
         self.should_compute: bool = True
 
+        # MagCache accumulators
         self.accumulated_ratio: float = 1.0
         self.accumulated_err: float = 0.0
         self.accumulated_steps: int = 0
+
+        # Current step counter (timestep index)
         self.step_index: int = 0
+
+        # Calibration storage
+        self.calibration_ratios: List[float] = []
 
     def reset(self):
         self.previous_residual = None
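
Note: the two-phase workflow this validation enforces might look as follows (a minimal sketch based on the config fields above; the import path is the one named in the error message):

from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS, MagCacheConfig

# Phase 1: calibration. There are no ratios yet, so calibrate=True is required;
# omitting both would raise the ValueError introduced in this hunk.
calibration_config = MagCacheConfig(calibrate=True, num_inference_steps=28)

# Phase 2: inference. Pass the ratios printed at the end of the calibration run,
# or the bundled Flux defaults, which this commit stops applying implicitly.
inference_config = MagCacheConfig(
    mag_ratios=FLUX_MAG_RATIOS,  # or an np.ndarray copied from the calibration output
    num_inference_steps=28,
)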
@@ -137,6 +162,7 @@ def reset(self):
         self.accumulated_err = 0.0
         self.accumulated_steps = 0
         self.step_index = 0
+        self.calibration_ratios = []
 
 
 class MagCacheHeadHook(ModelHook):
@@ -153,36 +179,42 @@ def initialize_hook(self, module):
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        # Capture the input hidden_states
         hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
 
         state: MagCacheState = self.state_manager.get_state()
         state.head_block_input = hidden_states
 
         should_compute = True
 
-        current_step = state.step_index
-        if current_step >= len(self.config.mag_ratios):
-            current_scale = 1.0
+        if self.config.calibrate:
+            # Never skip during calibration
+            should_compute = True
         else:
-            current_scale = self.config.mag_ratios[current_step]
+            # MagCache logic
+            current_step = state.step_index
+            if current_step >= len(self.config.mag_ratios):
+                current_scale = 1.0
+            else:
+                current_scale = self.config.mag_ratios[current_step]
 
-        retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
+            retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
 
-        if current_step >= retention_step:
-            state.accumulated_ratio *= current_scale
-            state.accumulated_steps += 1
-            state.accumulated_err += abs(1.0 - state.accumulated_ratio)
+            if current_step >= retention_step:
+                state.accumulated_ratio *= current_scale
+                state.accumulated_steps += 1
+                state.accumulated_err += abs(1.0 - state.accumulated_ratio)
 
-        if (
-            state.previous_residual is not None
-            and state.accumulated_err <= self.config.threshold
-            and state.accumulated_steps <= self.config.max_skip_steps
-        ):
-            should_compute = False
-        else:
-            state.accumulated_ratio = 1.0
-            state.accumulated_steps = 0
-            state.accumulated_err = 0.0
+            if (
+                state.previous_residual is not None
+                and state.accumulated_err <= self.config.threshold
+                and state.accumulated_steps <= self.config.max_skip_steps
+            ):
+                should_compute = False
+            else:
+                state.accumulated_ratio = 1.0
+                state.accumulated_steps = 0
+                state.accumulated_err = 0.0
 
         state.should_compute = should_compute
 
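To make the accumulator arithmetic above concrete, here is a standalone simulation of the skip decision. It is a simplified sketch, not the hook itself: the previous_residual check is folded away, pre-retention steps are treated as unconditional computes, and the ratios are invented.

import numpy as np

mag_ratios = np.linspace(1.0, 0.97, 28)  # invented; real values come from calibration
threshold, max_skip_steps, num_steps = 0.24, 5, 28
retention_step = int(0.1 * num_steps + 0.5)  # retention_ratio = 0.1

acc_ratio, acc_err, acc_steps = 1.0, 0.0, 0
for step in range(num_steps):
    if step < retention_step:
        print(step, "compute (retention)")
        continue
    acc_ratio *= mag_ratios[step]
    acc_steps += 1
    acc_err += abs(1.0 - acc_ratio)
    if acc_err <= threshold and acc_steps <= max_skip_steps:
        print(step, "skip (reuse cached residual)")
    else:
        acc_ratio, acc_steps, acc_err = 1.0, 0, 0.0  # reset after a full compute
        print(step, "compute")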
@@ -193,6 +225,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             output = hidden_states
             res = state.previous_residual
 
+            # Attempt to apply the residual, handling shape mismatches (e.g., text+image vs. image-only)
             if res.shape == output.shape:
                 output = output + res
             elif (
@@ -201,6 +234,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
                 and output.shape[0] == res.shape[0]
                 and output.shape[2] == res.shape[2]
             ):
+                # Assuming concatenation where the image part is at the end (standard in Flux/SD3)
                 diff = output.shape[1] - res.shape[1]
                 if diff > 0:
                     output = output.clone()
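
The branch above applies a cached residual that covers only part of the sequence. A toy illustration with invented sizes, assuming the [text, image] concatenation order the comment describes (the exact slice the hook uses falls outside this hunk):

import torch

output = torch.randn(2, 512 + 64, 3072)  # text (512) + image (64) tokens
res = torch.randn(2, 64, 3072)           # cached residual for the image tokens only

diff = output.shape[1] - res.shape[1]
if diff > 0:
    output = output.clone()  # avoid mutating a tensor that may be reused elsewhere
    output[:, diff:, :] = output[:, diff:, :] + res  # image part sits at the end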
@@ -220,20 +254,18 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
                 original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                     "encoder_hidden_states", args, kwargs
                 )
-
                 max_idx = max(
                     self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
                 )
                 ret_list = [None] * (max_idx + 1)
-
                 ret_list[self._metadata.return_hidden_states_index] = output
                 ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
-
                 return tuple(ret_list)
             else:
                 return output
 
         else:
+            # Compute original forward
             output = self.fn_ref.original_forward(*args, **kwargs)
             return output
 
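The ret_list construction above (and again in the skip path below) simply places each tensor at the index the caller of the wrapped block expects. Schematically, with invented index values:

return_hidden_states_index = 1
return_encoder_hidden_states_index = 0

output = "hidden_states_tensor"
original_encoder_hidden_states = "encoder_hidden_states_tensor"

max_idx = max(return_hidden_states_index, return_encoder_hidden_states_index)
ret_list = [None] * (max_idx + 1)
ret_list[return_hidden_states_index] = output
ret_list[return_encoder_hidden_states_index] = original_encoder_hidden_states
print(tuple(ret_list))  # ('encoder_hidden_states_tensor', 'hidden_states_tensor')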
@@ -260,21 +292,14 @@
 
         if not state.should_compute:
             hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
-
             if self.is_tail:
-                state.step_index += 1
-                if state.step_index >= self.config.num_inference_steps:
-                    state.step_index = 0
-                    state.accumulated_ratio = 1.0
-                    state.accumulated_steps = 0
-                    state.accumulated_err = 0.0
-                    state.previous_residual = None
+                # Still need to advance the step index even if we skip
+                self._advance_step(state)
 
             if self._metadata.return_encoder_hidden_states_index is not None:
                 encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                     "encoder_hidden_states", args, kwargs
                 )
-
                 max_idx = max(
                     self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
                 )
@@ -285,38 +310,71 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
             return hidden_states
 
-
         output = self.fn_ref.original_forward(*args, **kwargs)
 
         if self.is_tail:
+            # Calculate the residual for the next steps
             if isinstance(output, tuple):
                 out_hidden = output[self._metadata.return_hidden_states_index]
             else:
                 out_hidden = output
 
             in_hidden = state.head_block_input
-
+
+            # Determine the residual
             if out_hidden.shape == in_hidden.shape:
                 residual = out_hidden - in_hidden
             elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
                 diff = in_hidden.shape[1] - out_hidden.shape[1]
                 if diff == 0:
                     residual = out_hidden - in_hidden
                 else:
-                    residual = out_hidden - in_hidden
+                    residual = out_hidden - in_hidden  # Fallback to matching tail
+            else:
+                # Fallback for completely mismatched shapes
+                residual = out_hidden  # Invalid, but prevents a crash
 
-            state.previous_residual = residual
+            if self.config.calibrate:
+                self._perform_calibration_step(state, residual)
 
-            state.step_index += 1
-            if state.step_index >= self.config.num_inference_steps:
-                state.step_index = 0
-                state.accumulated_ratio = 1.0
-                state.accumulated_steps = 0
-                state.accumulated_err = 0.0
-                state.previous_residual = None
+            state.previous_residual = residual
+            self._advance_step(state)
 
         return output
 
+    def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
+        if state.previous_residual is None:
+            # The first step has no previous residual to compare against.
+            # We log 1.0 as a neutral starting point.
+            ratio = 1.0
+        else:
+            # MagCache calibration formula: mean(norm(curr) / norm(prev))
+            # norm(dim=-1) gives the magnitude of each token vector
+            curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
+            prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
+
+            # Avoid division by zero
+            ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
+
+        state.calibration_ratios.append(ratio)
+
+    def _advance_step(self, state: MagCacheState):
+        state.step_index += 1
+        if state.step_index >= self.config.num_inference_steps:
+            # End of the inference loop
+            if self.config.calibrate:
+                print("\n[MagCache] Calibration complete. Copy these values to MagCacheConfig(mag_ratios=...):")
+                print(f"{state.calibration_ratios}\n")
+                logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
+
+            # Reset state
+            state.step_index = 0
+            state.accumulated_ratio = 1.0
+            state.accumulated_steps = 0
+            state.accumulated_err = 0.0
+            state.previous_residual = None
+            state.calibration_ratios = []
 
 
 def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
     """
@@ -331,7 +389,6 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
     state_manager = StateManager(MagCacheState, (), {})
     remaining_blocks = []
 
-    # Identify blocks
     for name, submodule in module.named_children():
         if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
             continue
@@ -342,6 +399,16 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
         logger.warning("MagCache: No transformer blocks found to apply hooks.")
         return
 
+    if len(remaining_blocks) == 1:
+        # Single-block case: it acts as both the Head (decision) and the Tail (residual calculation)
+        name, block = remaining_blocks[0]
+        logger.info(f"MagCache: Applying Head+Tail hooks to single block '{name}'")
+        # Apply the BlockHook (Tail) FIRST so it is the INNER wrapper
+        _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
+        # Apply the HeadHook SECOND so it is the OUTER wrapper (controls flow)
+        _apply_mag_cache_head_hook(block, state_manager, config)
+        return
+
     head_block_name, head_block = remaining_blocks.pop(0)
     tail_block_name, tail_block = remaining_blocks.pop(-1)
 
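The FIRST/SECOND comments in the single-block branch rely on later registrations wrapping earlier ones. A language-level toy of that nesting (this sketches the assumed HookRegistry behavior, not its actual implementation):

def forward(x):
    return x + 1

def wrap(fn, label):
    def wrapped(x):
        print("enter", label)  # outer wrappers run first
        result = fn(x)
        print("exit", label)
        return result
    return wrapped

forward = wrap(forward, "BlockHook (tail)")  # registered first  -> inner wrapper
forward = wrap(forward, "HeadHook")          # registered second -> outer wrapper, controls flow
forward(0)  # enter HeadHook -> enter BlockHook (tail) -> exit BlockHook (tail) -> exit HeadHook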
@@ -371,4 +438,4 @@ def _apply_mag_cache_block_hook(
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
     hook = MagCacheBlockHook(state_manager, is_tail, config)
-    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
+    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
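
Putting the pieces together, usage of this commit might look like the following. This is a hypothetical sketch: the pipeline, checkpoint name, and the choice of pipe.transformer as the hook target are illustrative; only apply_mag_cache, MagCacheConfig, and FLUX_MAG_RATIOS come from the diff.

import torch
from diffusers import FluxPipeline
from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS, MagCacheConfig, apply_mag_cache

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Hook the transformer blocks; head/tail placement is chosen by apply_mag_cache.
config = MagCacheConfig(mag_ratios=FLUX_MAG_RATIOS, num_inference_steps=28)
apply_mag_cache(pipe.transformer, config)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]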
