-
-
Notifications
You must be signed in to change notification settings - Fork 163
add support for ros3p rosenbrock method #709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there! This is super awesome. I've been meaning to add Rosenbrock methods for a while, so I'd love to get this in. I have some fairly nitty comments but the structure of this PR already looks excellent.
diffrax/_solver/ros3p.py
Outdated
| control = terms.contr(t0, t1) | ||
|
|
||
| # common L.H.S | ||
| A = (lx.MatrixLinearOperator(eye) / (control * self.tableau.γ[0])) - ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use lx.IdentityLinearOperator here.
diffrax/_solver/ros3p.py
Outdated
| class Ros3p(AbstractAdaptiveSolver): | ||
| r"""Ros3p method. | ||
| 3rd order Rosenbrock method for solving stiff equation. Uses a 1st order local linear |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect it may make more sense to use a third-order hermite interpolation by default. (Which is the usual standard interpolation method most of the time.)
diffrax/_solver/ros3p.py
Outdated
| ) | ||
| ) | ||
|
|
||
| # stage 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this might be a bit tricky, but it's quite important that the stages be wrapped up into a lax.scan. This is so that we don't compile the user-supplied vector field multiple times, as that hugely increases compile time for nontrivial vector fields.
diffrax/_solver/ros3p.py
Outdated
| solver_state: _SolverState, | ||
| made_jump: BoolScalarLike, | ||
| ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]: | ||
| del made_jump, solver_state |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've not checked, is this method definitely not FSAL?
| f(1.0) | ||
|
|
||
|
|
||
| def test_ros3p(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we have a few tests that run pretty much every solver, it would be good to add ros3p to these as well.
diffrax/_solver/ros3p.py
Outdated
|
|
||
|
|
||
| _tableau = _RosenbrockTableau( | ||
| m_sol=jnp.array([2.0, 0.5773502691896258, 0.4226497308103742]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these should be regular numpy arrays, to avoid initialising the JAX backend (which happens the first time an array is created) whilst Diffrax is being imported.
|
|
||
| def step( | ||
| self, | ||
| terms: AbstractTerm[ArrayLike, ArrayLike], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional work is required to make it work with MultiTerm. Reading about other rosenbrock method will allow me to design the proper PyTree abstraction. So, I've limited the term structure to the simple ode.
I can implement this now or include it in the next PR along with the next Rosenbrock method.
test/test_sde1.py
Outdated
| # and Heun if the solver is Stratonovich. | ||
| @pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders()) | ||
| @pytest.mark.parametrize("dtype", (jnp.float64,)) | ||
| @pytest.mark.skip(reason="This test is failing in the main the branch") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test has been fixed in the dev branch.
Should I raise the PR against the dev branch?
I changed the base branch to dev.
|
@patrick-kidger I've addressed your comments. Please review them when you get a chance. In the meantime, I'll start reading about other methods. Thanks |
Hello,
I'm new to JAX and numerical computing, and willing to invest the time to learn by implementing numerical methods. After this, I plan to add support for DAE and additional rosenbrock methods. I would appreciate your guidance on getting this PR merged.
Thanks