@@ -103,6 +103,7 @@ def _solve(
103103
104104 def cond_fn (val ):
105105 _ , step , diffsize , diffsize_prev = val
106+ at_least_two = step < 2
106107 rate = diffsize / diffsize_prev
107108 factor = diffsize * rate / (1 - rate )
108109 if self .max_steps is None :
@@ -112,7 +113,7 @@ def cond_fn(val):
112113 not_small = ~ _small (diffsize )
113114 not_diverged = ~ _diverged (rate )
114115 not_converged = ~ _converged (factor , self .kappa )
115- return step_okay & not_small & not_diverged & not_converged
116+ return at_least_two | ( step_okay & not_small & not_diverged & not_converged )
116117
117118 def body_fn (val ):
118119 flat , step , diffsize , _ = val
@@ -128,10 +129,7 @@ def body_fn(val):
128129 val = (flat , step + 1 , diffsize , diffsize_prev )
129130 return val
130131
131- # Unconditionally execute two loops to fill in diffsize and diffsize_prev.
132- val = (flat , 0 , None , None )
133- val = body_fn (val )
134- val = body_fn (val )
132+ val = (flat , 0 , 0.0 , 0.0 )
135133 val = lax .while_loop (cond_fn , body_fn , val )
136134 flat , num_steps , diffsize , diffsize_prev = val
137135
0 commit comments