@@ -162,7 +162,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
162162
163163 current_step = state .step_index
164164 if current_step >= len (self .config .mag_ratios ):
165- # Safety fallback if steps exceed config
166165 current_scale = 1.0
167166 else :
168167 current_scale = self .config .mag_ratios [current_step ]
@@ -174,16 +173,13 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
174173 state .accumulated_steps += 1
175174 state .accumulated_err += abs (1.0 - state .accumulated_ratio )
176175
177- # Check skip condition
178- # We must have a previous residual to skip
179176 if (
180177 state .previous_residual is not None
181178 and state .accumulated_err <= self .config .threshold
182179 and state .accumulated_steps <= self .config .max_skip_steps
183180 ):
184181 should_compute = False
185182 else :
186- # Reset accumulators if we decide to compute (and we are past retention)
187183 state .accumulated_ratio = 1.0
188184 state .accumulated_steps = 0
189185 state .accumulated_err = 0.0
@@ -207,7 +203,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
207203 ):
208204 diff = output .shape [1 ] - res .shape [1 ]
209205 if diff > 0 :
210- # Add residual to the end
211206 output = output .clone ()
212207 output [:, diff :, :] = output [:, diff :, :] + res
213208 else :
@@ -239,7 +234,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
239234 return output
240235
241236 else :
242- # Run original forward
243237 output = self .fn_ref .original_forward (* args , ** kwargs )
244238 return output
245239
@@ -302,15 +296,13 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
302296
303297 in_hidden = state .head_block_input
304298
305- # Calculate residual
306299 if out_hidden .shape == in_hidden .shape :
307300 residual = out_hidden - in_hidden
308301 elif out_hidden .ndim == 3 and in_hidden .ndim == 3 and out_hidden .shape [2 ] == in_hidden .shape [2 ]:
309302 diff = in_hidden .shape [1 ] - out_hidden .shape [1 ]
310303 if diff == 0 :
311304 residual = out_hidden - in_hidden
312305 else :
313- # Fallback: Just calculate residual on matching tail (Image part usually at end)
314306 residual = out_hidden - in_hidden
315307
316308 state .previous_residual = residual
0 commit comments