55import equinox as eqx
66import equinox .internal as eqxi
77import jax
8- import jax .flatten_util as jfu
98import jax .lax as lax
109import jax .numpy as jnp
1110import jax .tree_util as jtu
@@ -218,7 +217,19 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
218217# numerical behaviour since the iteration is close to 0. (Although we have
219218# multiplied by the increment of the control, i.e. dt, which is small...)
220219def _implicit_relation_f (fi , nonlinear_solve_args ):
221- diagonal , vf , prod , ti , yi_partial , args , control = nonlinear_solve_args
220+ # We pass stage_index, even without using it, so that custom nonlinear solvers
221+ # can special-case on the stage if they want to.
222+ (
223+ stage_index ,
224+ diagonal ,
225+ vf ,
226+ prod ,
227+ ti ,
228+ yi_partial ,
229+ args ,
230+ control ,
231+ ) = nonlinear_solve_args
232+ del stage_index
222233 diff = (
223234 fi ** ω
224235 - vf (ti , (yi_partial ** ω + diagonal * prod (fi , control ) ** ω ).ω , args ) ** ω
@@ -231,7 +242,11 @@ def _implicit_relation_k(ki, nonlinear_solve_args):
231242 # c.f:
232243 # https://github.com/SciML/DiffEqDevMaterials/blob/master/newton/output/main.pdf
233244 # (Bearing in mind that our ki is dt times smaller than theirs.)
234- diagonal , vf_prod , ti , yi_partial , args , control = nonlinear_solve_args
245+ #
246+ # We pass stage_index, even without using it, so that custom nonlinear solvers
247+ # can special-case on the stage if they want to.
248+ stage_index , diagonal , vf_prod , ti , yi_partial , args , control = nonlinear_solve_args
249+ del stage_index
235250 diff = (
236251 ki ** ω
237252 - vf_prod (ti , (yi_partial ** ω + diagonal * ki ** ω ).ω , args , control ) ** ω
@@ -732,6 +747,25 @@ def embed_c(tab):
732747 implicit_predictor = jnp .asarray (implicit_predictor )
733748 implicit_c = get_implicit (tableaus_c )
734749
750+ if implicit_term is None :
751+ implicit_vf = _unused
752+ implicit_prod = _unused
753+ implicit_vf_prod = _unused
754+ else :
755+ if eval_fs :
756+ assert f0 is not _unused
757+ implicit_vf = eqx .filter_closure_convert (implicit_term .vf , t0 , y0 , args )
758+ implicit_prod = eqx .filter_closure_convert (
759+ implicit_term .prod , get_implicit (f0 ), implicit_control
760+ )
761+ implicit_vf_prod = _unused
762+ else :
763+ implicit_vf = _unused
764+ implicit_prod = _unused
765+ implicit_vf_prod = eqx .filter_closure_convert (
766+ implicit_term .vf_prod , t0 , y0 , args , implicit_control
767+ )
768+
735769 #
736770 # Run the loop over stages. (This is what you signed up for, and it's taken us
737771 # several hundred lines of code just to get this far!)
@@ -791,11 +825,10 @@ def rk_stage(val):
791825 f_pred = jtu .tree_map (if_first_stage , f0_for_jac , f_pred )
792826 assert f0 is not _unused
793827 f_implicit_args = (
828+ stage_index ,
794829 implicit_diagonal_i ,
795- eqx .filter_closure_convert (implicit_term .vf , t0 , y0 , args ),
796- eqx .filter_closure_convert (
797- implicit_term .prod , f_pred , implicit_control
798- ),
830+ implicit_vf ,
831+ implicit_prod ,
799832 implicit_ti ,
800833 yi_partial ,
801834 args ,
@@ -814,10 +847,9 @@ def rk_stage(val):
814847 # doesn't matter.
815848 k_pred = jtu .tree_map (if_first_stage , k0_for_jac , k_pred )
816849 k_implicit_args = (
850+ stage_index ,
817851 implicit_diagonal_i ,
818- eqx .filter_closure_convert (
819- implicit_term .vf_prod , t0 , y0 , args , implicit_control
820- ),
852+ implicit_vf_prod ,
821853 implicit_ti ,
822854 yi_partial ,
823855 args ,
@@ -954,23 +986,41 @@ def buffers(val):
954986 # For DIRK and SDIRK methods then the choice here doesn't matter; we compute
955987 # the Jacobian straight away.
956988 # For ESDIRK methods, this is the Jacobian of an explicit step.
957- #
958- # TODO: fix once we have more advanced nonlinear solvers.
959- # Mildly hacky hardcoding for now.
960989 if eval_fs :
961990 assert f0 is not _unused
962- struct = jax .eval_shape (lambda : jfu .ravel_pytree (get_implicit (f0 ))[0 ])
963- jac_f = (
964- jnp .eye (struct .size , dtype = struct .dtype ),
965- jnp .arange (struct .size , dtype = jnp .int32 ),
991+ f_implicit_args = (
992+ jnp .array (0 ),
993+ # zero diagonal == identity matrix as the Jacobian
994+ jnp .array (0.0 , dtype = implicit_diagonal .dtype ),
995+ implicit_vf ,
996+ implicit_prod ,
997+ t0 ,
998+ y0 ,
999+ args ,
1000+ implicit_control ,
1001+ )
1002+ jac_f = self .nonlinear_solver .jac (
1003+ _implicit_relation_f ,
1004+ jtu .tree_map (jnp .zeros_like , get_implicit (f0 )),
1005+ _filter_stop_gradient (f_implicit_args ),
9661006 )
9671007 jac_k = _unused
9681008 else :
969- struct = jax .eval_shape (lambda : jfu .ravel_pytree (y0 )[0 ])
1009+ k_implicit_args = (
1010+ jnp .array (0 ),
1011+ # zero diagonal == identity matrix as the Jacobian
1012+ jnp .array (0.0 , dtype = implicit_diagonal .dtype ),
1013+ implicit_vf_prod ,
1014+ t0 ,
1015+ y0 ,
1016+ args ,
1017+ implicit_control ,
1018+ )
9701019 jac_f = _unused
971- jac_k = (
972- jnp .eye (struct .size , dtype = struct .dtype ),
973- jnp .arange (struct .size , dtype = jnp .int32 ),
1020+ jac_k = self .nonlinear_solver .jac (
1021+ _implicit_relation_k ,
1022+ jtu .tree_map (jnp .zeros_like , y0 ),
1023+ _filter_stop_gradient (k_implicit_args ),
9741024 )
9751025 init_val = (
9761026 init_stage_index ,
@@ -1065,7 +1115,7 @@ def __init_subclass__(cls, **kwargs):
10651115 diagonal = cls .tableau .a_diagonal [0 ]
10661116 assert (cls .tableau .a_diagonal == diagonal ).all ()
10671117
1068- calculate_jacobian = CalculateJacobian .second_stage
1118+ calculate_jacobian = CalculateJacobian .first_stage
10691119
10701120
10711121class AbstractESDIRK (AbstractDIRK ):
0 commit comments