Skip to content

Commit a8a57c6

Browse files
committed
formatting
1 parent 8dbb673 commit a8a57c6

File tree

1 file changed

+0
-8
lines changed

1 file changed

+0
-8
lines changed

src/diffusers/hooks/mag_cache.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
162162

163163
current_step = state.step_index
164164
if current_step >= len(self.config.mag_ratios):
165-
# Safety fallback if steps exceed config
166165
current_scale = 1.0
167166
else:
168167
current_scale = self.config.mag_ratios[current_step]
@@ -174,16 +173,13 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
174173
state.accumulated_steps += 1
175174
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
176175

177-
# Check skip condition
178-
# We must have a previous residual to skip
179176
if (
180177
state.previous_residual is not None
181178
and state.accumulated_err <= self.config.threshold
182179
and state.accumulated_steps <= self.config.max_skip_steps
183180
):
184181
should_compute = False
185182
else:
186-
# Reset accumulators if we decide to compute (and we are past retention)
187183
state.accumulated_ratio = 1.0
188184
state.accumulated_steps = 0
189185
state.accumulated_err = 0.0
@@ -207,7 +203,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
207203
):
208204
diff = output.shape[1] - res.shape[1]
209205
if diff > 0:
210-
# Add residual to the end
211206
output = output.clone()
212207
output[:, diff:, :] = output[:, diff:, :] + res
213208
else:
@@ -239,7 +234,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
239234
return output
240235

241236
else:
242-
# Run original forward
243237
output = self.fn_ref.original_forward(*args, **kwargs)
244238
return output
245239

@@ -302,15 +296,13 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
302296

303297
in_hidden = state.head_block_input
304298

305-
# Calculate residual
306299
if out_hidden.shape == in_hidden.shape:
307300
residual = out_hidden - in_hidden
308301
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
309302
diff = in_hidden.shape[1] - out_hidden.shape[1]
310303
if diff == 0:
311304
residual = out_hidden - in_hidden
312305
else:
313-
# Fallback: Just calculate residual on matching tail (Image part usually at end)
314306
residual = out_hidden - in_hidden
315307

316308
state.previous_residual = residual

0 commit comments

Comments
 (0)