Skip to content

Commit 0ab618e

Browse files
Improved compile times when using implicit solvers. On a benchmark problem this went from 20 seconds to 12 seconds.
1 parent cd1011e commit 0ab618e

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

diffrax/nonlinear_solver/newton.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _solve(
103103

104104
def cond_fn(val):
105105
_, step, diffsize, diffsize_prev = val
106+
at_least_two = step < 2
106107
rate = diffsize / diffsize_prev
107108
factor = diffsize * rate / (1 - rate)
108109
if self.max_steps is None:
@@ -112,7 +113,7 @@ def cond_fn(val):
112113
not_small = ~_small(diffsize)
113114
not_diverged = ~_diverged(rate)
114115
not_converged = ~_converged(factor, self.kappa)
115-
return step_okay & not_small & not_diverged & not_converged
116+
return at_least_two | (step_okay & not_small & not_diverged & not_converged)
116117

117118
def body_fn(val):
118119
flat, step, diffsize, _ = val
@@ -128,10 +129,7 @@ def body_fn(val):
128129
val = (flat, step + 1, diffsize, diffsize_prev)
129130
return val
130131

131-
# Unconditionally execute two loops to fill in diffsize and diffsize_prev.
132-
val = (flat, 0, None, None)
133-
val = body_fn(val)
134-
val = body_fn(val)
132+
val = (flat, 0, 0.0, 0.0)
135133
val = lax.while_loop(cond_fn, body_fn, val)
136134
flat, num_steps, diffsize, diffsize_prev = val
137135

0 commit comments

Comments
 (0)