@@ -117,6 +117,7 @@ def loop(
117117 solver ,
118118 stepsize_controller ,
119119 discrete_terminating_event ,
120+ delays ,
120121 saveat ,
121122 t0 ,
122123 t1 ,
@@ -522,13 +523,15 @@ def _loop_backsolve_bwd(
522523 solver ,
523524 stepsize_controller ,
524525 discrete_terminating_event ,
526+ delays ,
525527 saveat ,
526528 t0 ,
527529 t1 ,
528530 dt0 ,
529531 max_steps ,
530532 throw ,
531533 init_state ,
534+ y0_history ,
532535):
533536 assert discrete_terminating_event is None
534537
@@ -566,6 +569,8 @@ def _loop_backsolve_bwd(
566569 adjoint = self ,
567570 solver = solver ,
568571 stepsize_controller = stepsize_controller ,
572+ discrete_terminating_event = discrete_terminating_event ,
573+ delays = delays ,
569574 terms = adjoint_terms ,
570575 dt0 = None if dt0 is None else - dt0 ,
571576 max_steps = max_steps ,
@@ -745,6 +750,7 @@ def loop(
745750 passed_solver_state ,
746751 passed_controller_state ,
747752 discrete_terminating_event ,
753+ delays ,
748754 ** kwargs ,
749755 ):
750756 if jtu .tree_structure (saveat .subs , is_leaf = _is_subsaveat ) != jtu .tree_structure (
@@ -790,6 +796,10 @@ def loop(
790796 raise NotImplementedError (
791797 "`diffrax.BacksolveAdjoint` is not compatible with events."
792798 )
799+ if delays is not None :
800+ raise NotImplementedError (
801+ "Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
802+ )
793803
794804 y = init_state .y
795805 init_state = eqx .tree_at (lambda s : s .y , init_state , object ())
@@ -804,6 +814,7 @@ def loop(
804814 init_state = init_state ,
805815 solver = solver ,
806816 discrete_terminating_event = discrete_terminating_event ,
817+ delays = delays ,
807818 ** kwargs ,
808819 )
809820 final_state = _only_transpose_ys (final_state )
0 commit comments