Skip to content

Conversation

@poonai
Copy link

@poonai poonai commented Nov 20, 2025

Hello,

I'm new to JAX and numerical computing, and willing to invest the time to learn by implementing numerical methods. After this, I plan to add support for DAE and additional rosenbrock methods. I would appreciate your guidance on getting this PR merged.

Thanks

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! This is super awesome. I've been meaning to add Rosenbrock methods for a while, so I'd love to get this in. I have some fairly nitty comments but the structure of this PR already looks excellent.

control = terms.contr(t0, t1)

# common L.H.S
A = (lx.MatrixLinearOperator(eye) / (control * self.tableau.γ[0])) - (
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use lx.IdentityLinearOperator here.

class Ros3p(AbstractAdaptiveSolver):
r"""Ros3p method.
3rd order Rosenbrock method for solving stiff equation. Uses a 1st order local linear
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect it may make more sense to use a third-order hermite interpolation by default. (Which is the usual standard interpolation method most of the time.)

)
)

# stage 1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this might be a bit tricky, but it's quite important that the stages be wrapped up into a lax.scan. This is so that we don't compile the user-supplied vector field multiple times, as that hugely increases compile time for nontrivial vector fields.

solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
del made_jump, solver_state
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've not checked, is this method definitely not FSAL?

f(1.0)


def test_ros3p():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have a few tests that run pretty much every solver, it would be good to add ros3p to these as well.



_tableau = _RosenbrockTableau(
m_sol=jnp.array([2.0, 0.5773502691896258, 0.4226497308103742]),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be regular numpy arrays, to avoid initialising the JAX backend (which happens the first time an array is created) whilst Diffrax is being imported.


def step(
self,
terms: AbstractTerm[ArrayLike, ArrayLike],
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional work is required to make it work with MultiTerm. Reading about other rosenbrock method will allow me to design the proper PyTree abstraction. So, I've limited the term structure to the simple ode.

I can implement this now or include it in the next PR along with the next Rosenbrock method.

# and Heun if the solver is Stratonovich.
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
@pytest.mark.parametrize("dtype", (jnp.float64,))
@pytest.mark.skip(reason="This test is failing in the main the branch")
Copy link
Author

@poonai poonai Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been fixed in the dev branch.

Should I raise the PR against the dev branch?

I changed the base branch to dev.

@poonai
Copy link
Author

poonai commented Nov 28, 2025

@patrick-kidger I've addressed your comments. Please review them when you get a chance. In the meantime, I'll start reading about other methods.

Thanks

@poonai poonai changed the base branch from main to dev November 28, 2025 11:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants