11import abc
22import functools as ft
33import warnings
4- from typing import Any , Optional
4+ from collections .abc import Iterable
5+ from typing import Any , Optional , Union
56
67import equinox as eqx
78import equinox .internal as eqxi
89import jax
910import jax .lax as lax
1011import jax .numpy as jnp
1112import jax .tree_util as jtu
13+ import lineax as lx
14+ import optimistix .internal as optxi
1215from equinox .internal import ω
1316
14- from ._ad import implicit_jvp
1517from ._heuristics import is_sde , is_unsafe_sde
1618from ._saveat import save_y , SaveAt , SubSaveAt
1719from ._solver import AbstractItoSolver , AbstractRungeKutta , AbstractStratonovichSolver
@@ -384,7 +386,7 @@ def loop(
384386 return final_state
385387
386388
387- def _vf (ys , residual , args__terms , closure ):
389+ def _vf (ys , residual , inputs ):
388390 state_no_y , _ = residual
389391 t = state_no_y .tprev
390392
@@ -393,14 +395,12 @@ def _unpack(_y):
393395 return _y1
394396
395397 y = jtu .tree_map (_unpack , ys )
396- args , terms = args__terms
397- _ , _ , solver , _ , _ = closure
398+ args , terms , _ , _ , solver , _ , _ = inputs
398399 return solver .func (terms , t , y , args )
399400
400401
401- def _solve (args__terms , closure ):
402- args , terms = args__terms
403- self , kwargs , solver , saveat , init_state = closure
402+ def _solve (inputs ):
403+ args , terms , self , kwargs , solver , saveat , init_state = inputs
404404 final_state , aux_stats = self ._loop (
405405 ** kwargs ,
406406 args = args ,
@@ -423,6 +423,15 @@ def _solve(args__terms, closure):
423423 )
424424
425425
426+ def _frozenset (x : Union [object , Iterable [object ]]) -> frozenset [object ]:
427+ try :
428+ iter_x = iter (x ) # pyright: ignore
429+ except TypeError :
430+ return frozenset ([x ])
431+ else :
432+ return frozenset (iter_x )
433+
434+
426435class ImplicitAdjoint (AbstractAdjoint ):
427436 r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
428437
@@ -433,8 +442,16 @@ class ImplicitAdjoint(AbstractAdjoint):
433442 the solver and instead directly compute
434443 $\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$
435444 via the implicit function theorem.
445+
446+ Observe that this involves solving a linear system with matrix given by the Jacobian
447+ `df/dy`.
436448 """ # noqa: E501
437449
450+ linear_solver : lx .AbstractLinearSolver = lx .AutoLinearSolver (well_posed = None )
451+ tags : frozenset [object ] = eqx .field (
452+ default_factory = frozenset , converter = _frozenset , static = True
453+ )
454+
438455 def loop (
439456 self ,
440457 * ,
@@ -459,8 +476,10 @@ def loop(
459476 init_state = _nondiff_solver_controller_state (
460477 self , init_state , passed_solver_state , passed_controller_state
461478 )
462- closure = (self , kwargs , solver , saveat , init_state )
463- ys , residual = implicit_jvp (_solve , _vf , (args , terms ), closure )
479+ inputs = (args , terms , self , kwargs , solver , saveat , init_state )
480+ ys , residual = optxi .implicit_jvp (
481+ _solve , _vf , inputs , self .tags , self .linear_solver
482+ )
464483
465484 final_state_no_ys , aux_stats = residual
466485 # Note that `final_state.save_state` has type PyTree[SaveState]. To access `.ys`
@@ -473,6 +492,15 @@ def loop(
473492 return final_state , aux_stats
474493
475494
495+ ImplicitAdjoint .__init__ .__doc__ = """**Arguments:**
496+
497+ - `linear_solver`: A [Lineax](https://github.com/google/lineax) solver for solving the
498+ linear system.
499+ - `tags`: Any Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
500+ Jacobian matrix `df/dy`.
501+ """
502+
503+
476504# Compute derivatives with respect to the first argument:
477505# - y, corresponding to the initial state;
478506# - args, corresponding to explicit parameters;
0 commit comments