From 241df7901d84009ec7830cdd9fb678b2d433d999 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sat, 13 Sep 2025 22:53:39 +0000 Subject: [PATCH 1/4] make a start --- diffsol/src/error.rs | 2 + diffsol/src/ode_equations/mod.rs | 32 +- diffsol/src/ode_solver/bdf.rs | 4 + diffsol/src/ode_solver/explicit_rk.rs | 9 + diffsol/src/ode_solver/explicit_sde_rk.rs | 390 ++++++++++++++++++++++ diffsol/src/ode_solver/method.rs | 2 +- diffsol/src/ode_solver/mod.rs | 1 + diffsol/src/ode_solver/runge_kutta.rs | 2 +- diffsol/src/ode_solver/sdirk.rs | 7 +- 9 files changed, 428 insertions(+), 21 deletions(-) create mode 100644 diffsol/src/ode_solver/explicit_sde_rk.rs diff --git a/diffsol/src/error.rs b/diffsol/src/error.rs index 79ee5e4a..0437ed0a 100644 --- a/diffsol/src/error.rs +++ b/diffsol/src/error.rs @@ -68,6 +68,8 @@ pub enum OdeSolverError { StopTimeBeforeCurrentTime { stop_time: f64, state_time: f64 }, #[error("Mass matrix not supported for this solver")] MassMatrixNotSupported, + #[error("Stochastic RHS term not supported for this solver")] + StochNotSupported, #[error("Stop time is at the current state time")] StopTimeAtCurrentTime, #[error("Interpolation time is after current time")] diff --git a/diffsol/src/ode_equations/mod.rs b/diffsol/src/ode_equations/mod.rs index 53cec666..dd401639 100644 --- a/diffsol/src/ode_equations/mod.rs +++ b/diffsol/src/ode_equations/mod.rs @@ -197,6 +197,7 @@ pub trait OdeEquationsRef<'a, ImplicitBounds: Sealed = Bounds<&'a Self>>: Op { type Root: NonLinearOp; type Init: ConstantOp; type Out: NonLinearOp; + type Stoch: StochOp; } impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T { @@ -205,6 +206,7 @@ impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T { type Root = >::Root; type Init = >::Init; type Out = >::Out; + type Stoch = >::Stoch; } // seal the trait so that users must use the provided default type for ImplicitBounds @@ -235,7 +237,9 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> { fn rhs(&self) -> >::Rhs; /// returns the mass matrix `M` as a [LinearOp] - fn mass(&self) -> Option<>::Mass>; + fn mass(&self) -> Option<>::Mass> { + None + } /// returns the root function `G(t, y)` as a [NonLinearOp] fn root(&self) -> Option<>::Root> { @@ -247,6 +251,10 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> { None } + fn stoch(&self) -> Option<>::Stoch> { + None + } + /// returns the initial condition, i.e. `y(t)`, where `t` is the initial time fn init(&self) -> >::Init; @@ -307,7 +315,11 @@ impl OdeEquations for &'_ T { fn out(&self) -> Option<>::Out> { (*self).out() } - + + fn stoch(&self) -> Option<>::Stoch> { + (*self).stoch() + } + fn init(&self) -> >::Init { (*self).init() } @@ -331,22 +343,6 @@ impl OdeEquationsImplicit for T where { } -pub trait OdeEquationsStoch: - OdeEquations< - Rhs: NonLinearOp - + StochOp, -> -{ -} - -impl OdeEquationsStoch for T where - T: OdeEquations< - Rhs: NonLinearOp - + StochOp, - > -{ -} - pub trait OdeEquationsSens: OdeEquations< Rhs: NonLinearOpSens, diff --git a/diffsol/src/ode_solver/bdf.rs b/diffsol/src/ode_solver/bdf.rs index 43c1abbe..5a2f179b 100644 --- a/diffsol/src/ode_solver/bdf.rs +++ b/diffsol/src/ode_solver/bdf.rs @@ -198,6 +198,10 @@ where mut nonlinear_solver: Nls, integrate_main_eqn: bool, ) -> Result { + // check that there isn't any diffusion term + if problem.eqn.stoch().is_some() { + return Err(DiffsolError::from(OdeSolverError::StochNotSupported)); + } // kappa values for difference orders, taken from Table 1 of [1] let kappa = [ Eqn::T::from(0.0), diff --git a/diffsol/src/ode_solver/explicit_rk.rs b/diffsol/src/ode_solver/explicit_rk.rs index c68c6538..3aa09cee 100644 --- a/diffsol/src/ode_solver/explicit_rk.rs +++ b/diffsol/src/ode_solver/explicit_rk.rs @@ -1,6 +1,7 @@ use super::method::AugmentedOdeSolverMethod; use super::runge_kutta::Rk; use crate::error::DiffsolError; +use crate::error::OdeSolverError; use crate::ode_solver::bdf::BdfStatistics; use crate::vector::VectorRef; use crate::NoAug; @@ -81,6 +82,10 @@ where tableau: Tableau, ) -> Result { Rk::::check_explicit_rk(problem, &tableau)?; + // check that there isn't any diffusion term + if problem.eqn.stoch().is_some() { + return Err(DiffsolError::from(OdeSolverError::StochNotSupported)); + } Ok(Self { rk: Rk::new(problem, state, tableau)?, augmented_eqn: None, @@ -94,6 +99,10 @@ where augmented_eqn: AugmentedEqn, ) -> Result { Rk::::check_explicit_rk(problem, &tableau)?; + // check that there isn't any diffusion term + if problem.eqn.stoch().is_some() { + return Err(DiffsolError::from(OdeSolverError::StochNotSupported)); + } Ok(Self { rk: Rk::new_augmented(problem, state, tableau, &augmented_eqn)?, augmented_eqn: Some(augmented_eqn), diff --git a/diffsol/src/ode_solver/explicit_sde_rk.rs b/diffsol/src/ode_solver/explicit_sde_rk.rs new file mode 100644 index 00000000..0742beb4 --- /dev/null +++ b/diffsol/src/ode_solver/explicit_sde_rk.rs @@ -0,0 +1,390 @@ +use super::method::AugmentedOdeSolverMethod; +use super::runge_kutta::Rk; +use crate::error::DiffsolError; +use crate::ode_solver::bdf::BdfStatistics; +use crate::vector::VectorRef; +use crate::NoAug; +use crate::OdeSolverStopReason; +use crate::RkState; +use crate::Tableau; +use crate::{ + AugmentedOdeEquations, DefaultDenseMatrix, DenseMatrix, OdeEquations, OdeSolverMethod, + OdeSolverProblem, OdeSolverState, Op, StateRef, StateRefMut, +}; +use num_traits::One; + + +/// An explicit Runge-Kutta method for SDEs. +/// +/// The particular method is defined by the [Tableau] used to create the solver. +/// If the `beta` matrix of the [Tableau] is present this is used for interpolation, otherwise hermite interpolation is used. +/// +/// Restrictions: +/// - The upper triangular and diagonal parts of the `a` matrix must be zero (i.e. explicit). +/// - The last row of the `a` matrix must be the same as the `b` vector, and the last element of the `c` vector must be 1 (i.e. a stiffly accurate method) +pub struct ExplicitSdeRk< + 'a, + Eqn, + M = <::V as DefaultDenseMatrix>::M, + AugmentedEqn = NoAug, +> where + Eqn: OdeEquations, + M: DenseMatrix, + Eqn::V: DefaultDenseMatrix, + AugmentedEqn: AugmentedOdeEquations, +{ + rk: Rk<'a, Eqn, M>, + augmented_eqn: Option, +} + +impl Clone for ExplicitSdeRk<'_, Eqn, M, AugmentedEqn> +where + Eqn: OdeEquations, + M: DenseMatrix, + AugmentedEqn: AugmentedOdeEquations, + Eqn::V: DefaultDenseMatrix, +{ + fn clone(&self) -> Self { + Self { + rk: self.rk.clone(), + augmented_eqn: self.augmented_eqn.clone(), + } + } +} + +impl<'a, Eqn, M, AugmentedEqn> ExplicitSdeRk<'a, Eqn, M, AugmentedEqn> +where + Eqn: OdeEquations, + M: DenseMatrix, + AugmentedEqn: AugmentedOdeEquations, + Eqn::V: DefaultDenseMatrix, +{ + pub fn new( + problem: &'a OdeSolverProblem, + state: RkState, + tableau: Tableau, + ) -> Result { + Rk::::check_explicit_rk(problem, &tableau)?; + Ok(Self { + rk: Rk::new(problem, state, tableau)?, + augmented_eqn: None, + }) + } + + pub fn new_augmented( + problem: &'a OdeSolverProblem, + state: RkState, + tableau: Tableau, + augmented_eqn: AugmentedEqn, + ) -> Result { + Rk::::check_explicit_rk(problem, &tableau)?; + Ok(Self { + rk: Rk::new_augmented(problem, state, tableau, &augmented_eqn)?, + augmented_eqn: Some(augmented_eqn), + }) + } + + pub fn get_statistics(&self) -> &BdfStatistics { + self.rk.get_statistics() + } +} + +impl<'a, Eqn, M, AugmentedEqn> OdeSolverMethod<'a, Eqn> for ExplicitSdeRk<'a, Eqn, M, AugmentedEqn> +where + Eqn: OdeEquations, + M: DenseMatrix, + AugmentedEqn: AugmentedOdeEquations, + Eqn::V: DefaultDenseMatrix, +{ + type State = RkState; + + fn problem(&self) -> &'a OdeSolverProblem { + self.rk.problem() + } + + fn jacobian(&self) -> Option::M>> { + None + } + + fn mass(&self) -> Option::M>> { + None + } + + fn order(&self) -> usize { + self.rk.order() + } + + fn set_state(&mut self, state: Self::State) { + self.rk.set_state(state); + } + + fn into_state(self) -> RkState { + self.rk.into_state() + } + + fn checkpoint(&mut self) -> Self::State { + self.rk.checkpoint() + } + + fn step(&mut self) -> Result, DiffsolError> { + let mut h = self.rk.start_step()?; + + // loop until step is accepted + let mut nattempts = 0; + let factor = loop { + // start a step attempt + self.rk.start_step_attempt(h, self.augmented_eqn.as_mut()); + for i in 1..self.rk.tableau().s() { + self.rk.do_stage(i, h, self.augmented_eqn.as_mut()); + } + let error_norm = self.rk.error_norm(h, self.augmented_eqn.as_mut()); + let factor = self.rk.factor(error_norm, 1.0); + if error_norm < Eqn::T::one() { + break factor; + } + h *= factor; + nattempts += 1; + self.rk.error_test_fail(h, nattempts)?; + }; + self.rk.step_accepted(h, h * factor, false) + } + + fn set_stop_time(&mut self, tstop: ::T) -> Result<(), DiffsolError> { + self.rk.set_stop_time(tstop) + } + + fn interpolate_sens(&self, t: ::T) -> Result::V>, DiffsolError> { + self.rk.interpolate_sens(t) + } + + fn interpolate(&self, t: ::T) -> Result<::V, DiffsolError> { + self.rk.interpolate(t) + } + + fn interpolate_out(&self, t: ::T) -> Result<::V, DiffsolError> { + self.rk.interpolate_out(t) + } + + fn state(&self) -> StateRef<'_, Eqn::V> { + self.rk.state().as_ref() + } + + fn state_mut(&mut self) -> StateRefMut<'_, Eqn::V> { + self.rk.state_mut().as_mut() + } +} + +#[cfg(test)] +mod test { + use crate::{ + matrix::dense_nalgebra_serial::NalgebraMat, + ode_equations::test_models::exponential_decay::{ + exponential_decay_problem, exponential_decay_problem_adjoint, + exponential_decay_problem_sens, exponential_decay_problem_with_root, + negative_exponential_decay_problem, + }, + ode_solver::tests::{ + setup_test_adjoint, setup_test_adjoint_sum_squares, test_adjoint, + test_adjoint_sum_squares, test_checkpointing, test_interpolate, test_ode_solver, + test_problem, test_state_mut, test_state_mut_on_problem, + }, + Context, DenseMatrix, MatrixCommon, NalgebraLU, NalgebraVec, OdeEquations, OdeSolverMethod, + Op, Vector, VectorView, + }; + + use num_traits::abs; + + type M = NalgebraMat; + type LS = NalgebraLU; + + #[test] + fn explicit_rk_state_mut() { + test_state_mut(test_problem::().tsit45().unwrap()); + } + #[test] + fn explicit_rk_test_interpolate() { + test_interpolate(test_problem::().tsit45().unwrap()); + } + + #[test] + fn explicit_rk_test_checkpointing() { + let (problem, soln) = exponential_decay_problem::(false); + let s1 = problem.tsit45().unwrap(); + let s2 = problem.tsit45().unwrap(); + test_checkpointing(soln, s1, s2); + } + + #[test] + fn explicit_rk_test_state_mut_exponential_decay() { + let (p, soln) = exponential_decay_problem::(false); + let s = p.tsit45().unwrap(); + test_state_mut_on_problem(s, soln); + } + + #[test] + fn explicit_rk_test_nalgebra_negative_exponential_decay() { + let (problem, soln) = negative_exponential_decay_problem::(false); + let mut s = problem.tsit45().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); + } + + #[test] + fn test_tsit45_nalgebra_exponential_decay() { + let (problem, soln) = exponential_decay_problem::(false); + let mut s = problem.tsit45().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + number_of_linear_solver_setups: 0 + number_of_steps: 5 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 0 + "###); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 32 + number_of_jac_muls: 0 + number_of_matrix_evals: 0 + number_of_jac_adj_muls: 0 + "###); + } + + #[cfg(feature = "cuda")] + #[test] + fn test_tsit45_cuda_exponential_decay() { + let (problem, soln) = exponential_decay_problem::>(false); + let mut s = problem.tsit45().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + number_of_linear_solver_setups: 0 + number_of_steps: 5 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 0 + "###); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 32 + number_of_jac_muls: 0 + number_of_matrix_evals: 0 + number_of_jac_adj_muls: 0 + "###); + } + + #[test] + fn test_tsit45_nalgebra_exponential_decay_sens() { + let (problem, soln) = exponential_decay_problem_sens::(false); + let mut s = problem.tsit45_sens().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + number_of_linear_solver_setups: 0 + number_of_steps: 8 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 0 + "###); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 50 + number_of_jac_muls: 98 + number_of_matrix_evals: 0 + number_of_jac_adj_muls: 0 + "###); + } + + #[test] + fn explicit_rk_test_tsit45_exponential_decay_adjoint() { + let (mut problem, soln) = exponential_decay_problem_adjoint::(true); + let final_time = soln.solution_points.last().unwrap().t; + let dgdu = setup_test_adjoint::(&mut problem, soln); + let mut s = problem.tsit45().unwrap(); + let (checkpointer, _y, _t) = s.solve_with_checkpointing(final_time, None).unwrap(); + let adjoint_solver = problem.tsit45_solver_adjoint(checkpointer, None).unwrap(); + test_adjoint(adjoint_solver, dgdu); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 434 + number_of_jac_muls: 8 + number_of_matrix_evals: 4 + number_of_jac_adj_muls: 123 + "###); + } + + #[test] + fn explicit_rk_test_nalgebra_exponential_decay_adjoint_sum_squares() { + let (mut problem, soln) = exponential_decay_problem_adjoint::(false); + let times = soln.solution_points.iter().map(|p| p.t).collect::>(); + let (dgdp, data) = setup_test_adjoint_sum_squares::(&mut problem, times.as_slice()); + let (problem, _soln) = exponential_decay_problem_adjoint::(false); + let mut s = problem.tsit45().unwrap(); + let (checkpointer, soln) = s + .solve_dense_with_checkpointing(times.as_slice(), None) + .unwrap(); + let adjoint_solver = problem + .tsit45_solver_adjoint(checkpointer, Some(dgdp.ncols())) + .unwrap(); + test_adjoint_sum_squares(adjoint_solver, dgdp, soln, data, times.as_slice()); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 747 + number_of_jac_muls: 0 + number_of_matrix_evals: 0 + number_of_jac_adj_muls: 1707 + "###); + } + + #[test] + fn test_tstop_tsit45() { + let (problem, soln) = exponential_decay_problem::(false); + let mut s = problem.tsit45().unwrap(); + test_ode_solver(&mut s, soln, None, true, false); + } + + #[test] + fn test_root_finder_tsit45() { + let (problem, soln) = exponential_decay_problem_with_root::(false); + let mut s = problem.tsit45().unwrap(); + let y = test_ode_solver(&mut s, soln, None, false, false); + assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); + } + + #[test] + fn test_param_sweep_tsit45() { + let (mut problem, _soln) = exponential_decay_problem::(false); + let mut ps = Vec::new(); + for y0 in (1..10).map(f64::from) { + ps.push(problem.context().vector_from_vec(vec![0.1, y0])); + } + + let mut old_soln: Option> = None; + for p in ps { + problem.eqn_mut().set_params(&p); + let mut s = problem.tsit45().unwrap(); + let (ys, _ts) = s.solve(10.0).unwrap(); + // check that the new solution is different from the old one + if let Some(old_soln) = &mut old_soln { + let new_soln = ys.column(ys.ncols() - 1).into_owned(); + let error = new_soln - &*old_soln; + let diff = error + .squared_norm(old_soln, &problem.atol, problem.rtol) + .sqrt(); + assert!(diff > 1.0e-6, "diff: {diff}"); + } + old_soln = Some(ys.column(ys.ncols() - 1).into_owned()); + } + } + + #[cfg(feature = "diffsl-cranelift")] + #[test] + fn test_ball_bounce_tsit45() { + type M = crate::NalgebraMat; + let (x, v, t) = crate::ode_solver::tests::test_ball_bounce( + crate::ode_solver::tests::test_ball_bounce_problem::() + .tsit45() + .unwrap(), + ); + let expected_x = [6.375884661615263]; + let expected_v = [0.6878538646461059]; + let expected_t = [2.5]; + for (i, ((x, v), t)) in x.iter().zip(v.iter()).zip(t.iter()).enumerate() { + assert!((x - expected_x[i]).abs() < 1e-4); + assert!((v - expected_v[i]).abs() < 1e-4); + assert!((t - expected_t[i]).abs() < 1e-4); + } + } +} diff --git a/diffsol/src/ode_solver/method.rs b/diffsol/src/ode_solver/method.rs index 0aba3345..682746e7 100644 --- a/diffsol/src/ode_solver/method.rs +++ b/diffsol/src/ode_solver/method.rs @@ -172,7 +172,7 @@ where /// Get the current order of accuracy of the solver (e.g. explict euler method is first-order) fn order(&self) -> usize; - + /// Using the provided state, solve the problem up to time `final_time` /// Returns a Vec of solution values at timepoints chosen by the solver. /// After the solver has finished, the internal state of the solver is at time `final_time`. diff --git a/diffsol/src/ode_solver/mod.rs b/diffsol/src/ode_solver/mod.rs index ea8fae8b..9cd135cb 100644 --- a/diffsol/src/ode_solver/mod.rs +++ b/diffsol/src/ode_solver/mod.rs @@ -4,6 +4,7 @@ pub mod bdf_state; pub mod builder; pub mod checkpointing; pub mod explicit_rk; +pub mod explicit_sde_rk; pub mod jacobian_update; pub mod method; pub mod problem; diff --git a/diffsol/src/ode_solver/runge_kutta.rs b/diffsol/src/ode_solver/runge_kutta.rs index 063671e0..6abbeb16 100644 --- a/diffsol/src/ode_solver/runge_kutta.rs +++ b/diffsol/src/ode_solver/runge_kutta.rs @@ -212,7 +212,7 @@ where } Ok(ret) } - + pub(crate) fn check_explicit_rk( problem: &'a OdeSolverProblem, tableau: &Tableau, diff --git a/diffsol/src/ode_solver/sdirk.rs b/diffsol/src/ode_solver/sdirk.rs index b0f95353..5a8852a4 100644 --- a/diffsol/src/ode_solver/sdirk.rs +++ b/diffsol/src/ode_solver/sdirk.rs @@ -12,7 +12,7 @@ use crate::{ nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, AugmentedOdeEquations, AugmentedOdeEquationsImplicit, Convergence, DefaultDenseMatrix, DenseMatrix, JacobianUpdate, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op, StateRef, - StateRefMut, + StateRefMut, error::OdeSolverError }; use num_traits::One; @@ -136,6 +136,11 @@ where ) -> Result { let state = rk.state(); + // check that there isn't any diffusion term + if problem.eqn.stoch().is_some() { + return Err(DiffsolError::from(OdeSolverError::StochNotSupported)); + } + // setup linear solver for first step let mut jacobian_update = JacobianUpdate::default(); jacobian_update.update_jacobian(state.h); From 45da843376b6cc0fadcdb4511b3cf5498a6d2f6a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Wed, 24 Sep 2025 12:54:53 +0000 Subject: [PATCH 2/4] merge stoch tableau with det, draft out stages --- diffsol/src/lib.rs | 3 +- diffsol/src/ode_solver/explicit_sde_rk.rs | 4 +- diffsol/src/ode_solver/mod.rs | 1 - diffsol/src/ode_solver/runge_kutta.rs | 195 +++++++++++++++++++--- diffsol/src/ode_solver/sde.rs | 7 - diffsol/src/ode_solver/tableau.rs | 116 +++++++++++-- diffsol/src/op/stoch.rs | 64 ++----- 7 files changed, 291 insertions(+), 99 deletions(-) delete mode 100644 diffsol/src/ode_solver/sde.rs diff --git a/diffsol/src/lib.rs b/diffsol/src/lib.rs index 15e9ace8..67a43cff 100644 --- a/diffsol/src/lib.rs +++ b/diffsol/src/lib.rs @@ -190,7 +190,7 @@ pub use ode_equations::{ sens_equations::SensInit, sens_equations::SensRhs, AugmentedOdeEquations, AugmentedOdeEquationsImplicit, NoAug, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens, OdeEquationsRef, OdeEquationsSens, - OdeEquationsStoch, OdeSolverEquations, + OdeSolverEquations, }; use ode_solver::jacobian_update::JacobianUpdate; pub use ode_solver::sde::SdeSolverMethod; @@ -201,6 +201,7 @@ pub use ode_solver::{ method::AugmentedOdeSolverMethod, method::OdeSolverMethod, method::OdeSolverStopReason, problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::RkState, sensitivities::SensitivitiesOdeSolverMethod, state::OdeSolverState, tableau::Tableau, + tableau_sde::TableauSde, }; pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint}; pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose}; diff --git a/diffsol/src/ode_solver/explicit_sde_rk.rs b/diffsol/src/ode_solver/explicit_sde_rk.rs index 0742beb4..68ab6ffb 100644 --- a/diffsol/src/ode_solver/explicit_sde_rk.rs +++ b/diffsol/src/ode_solver/explicit_sde_rk.rs @@ -64,7 +64,7 @@ where state: RkState, tableau: Tableau, ) -> Result { - Rk::::check_explicit_rk(problem, &tableau)?; + Rk::::check_explicit_sde_rk(problem, &tableau)?; Ok(Self { rk: Rk::new(problem, state, tableau)?, augmented_eqn: None, @@ -77,7 +77,7 @@ where tableau: Tableau, augmented_eqn: AugmentedEqn, ) -> Result { - Rk::::check_explicit_rk(problem, &tableau)?; + Rk::::check_explicit_sde_rk(problem, &tableau)?; Ok(Self { rk: Rk::new_augmented(problem, state, tableau, &augmented_eqn)?, augmented_eqn: Some(augmented_eqn), diff --git a/diffsol/src/ode_solver/mod.rs b/diffsol/src/ode_solver/mod.rs index 9cd135cb..e7ff810a 100644 --- a/diffsol/src/ode_solver/mod.rs +++ b/diffsol/src/ode_solver/mod.rs @@ -9,7 +9,6 @@ pub mod jacobian_update; pub mod method; pub mod problem; pub mod runge_kutta; -pub mod sde; pub mod sdirk; pub mod sdirk_state; pub mod sensitivities; diff --git a/diffsol/src/ode_solver/runge_kutta.rs b/diffsol/src/ode_solver/runge_kutta.rs index 6abbeb16..ef6f6d7e 100644 --- a/diffsol/src/ode_solver/runge_kutta.rs +++ b/diffsol/src/ode_solver/runge_kutta.rs @@ -1,17 +1,20 @@ use crate::error::DiffsolError; use crate::error::OdeSolverError; +use crate::ode_solver::problem; use crate::op::sdirk::SdirkCallable; +use crate::op::stoch; use crate::scale; use crate::AugmentedOdeEquationsImplicit; use crate::OdeEquationsImplicit; use crate::OdeSolverStopReason; use crate::RkState; use crate::RootFinder; +use crate::StochOpKind; use crate::Tableau; use crate::{ ode_solver_error, AugmentedOdeEquations, Convergence, DefaultDenseMatrix, DenseMatrix, MatrixView, NonLinearOp, NonLinearSolver, OdeEquations, OdeSolverProblem, OdeSolverState, Op, - Scalar, Vector, VectorViewMut, + Scalar, Vector, VectorViewMut, StochOp, Matrix }; use num_traits::abs; use num_traits::One; @@ -38,16 +41,18 @@ where problem: &'a OdeSolverProblem, tableau: Tableau, state: RkState, - a_rows: Vec, + a_rows: Vec>, statistics: BdfStatistics, root_finder: Option>, tstop: Option, - diff: M, + diff: Vec, sdiff: Vec, sgdiff: Vec, gdiff: M, old_state: RkState, is_state_mutated: bool, + stoch_eval: ::M, + stoch_y: Option, error: Option, out_error: Option, @@ -80,6 +85,7 @@ where out_error: self.out_error.clone(), sens_error: self.sens_error.clone(), sens_out_error: self.sens_out_error.clone(), + stoch_eval: self.stoch_eval.clone(), } } } @@ -115,16 +121,25 @@ where let nstates = state.y.len(); let order = tableau.s(); + let ctx = problem.context(); + let (nprocess, kind) = if let Some(stoch) = problem.eqn.stoch() { + (stoch.nprocess(), stoch.kind()) + } else { + (0, StochOpKind::Other) + }; let s = tableau.s(); - let mut a_rows = Vec::with_capacity(s); - let ctx = problem.context(); - for i in 0..s { - let mut row = Vec::with_capacity(i); - for j in 0..i { - row.push(tableau.a().get_index(i, j)); + let mut a_rows = Vec::with_capacity(tableau.a().len()); + for a in tableau.a() { + let mut a_rows_i = Vec::with_capacity(s); + for i in 0..s { + let mut row = Vec::with_capacity(i); + for j in 0..i { + row.push(tableau.a().get_index(i, j)); + } + a_rows_i.push(Eqn::V::from_vec(row, ctx.clone())); } - a_rows.push(Eqn::V::from_vec(row, ctx.clone())); + a_rows.push(a_rows_i); } state.set_problem(problem)?; @@ -135,8 +150,17 @@ where } else { None }; - - let diff = M::zeros(nstates, order, ctx.clone()); + let n_stoch_diff = match kind { + StochOpKind::Scalar | StochOpKind::Diagonal => 1, + StochOpKind::Additive => nprocess, + StochOpKind::Other => 0, + }; + let stoch_y = match kind { + StochOpKind::Scalar | StochOpKind::Diagonal => Some(Eqn::V::zeros(nstates, ctx.clone())), + _ => None, + } + let stoch_eval = StochOpKind::Scalar => ::M::zeros(nstates, n_stoch_diff, ctx.clone()); + let diff = vec![M::zeros(nstates, order, ctx.clone()); 1 + n_stoch_diff]; let gdiff_rows = if problem.integrate_out { problem.eqn.out().unwrap().nout() } else { @@ -167,6 +191,8 @@ where root_finder, tstop: None, is_state_mutated: false, + stoch_eval, + stoch_y, diff, gdiff, sdiff: vec![], @@ -213,6 +239,34 @@ where Ok(ret) } + + pub(crate) fn check_explicit_sde_rk( + problem: &'a OdeSolverProblem, + tableau: &Tableau, + ) -> Result<(), DiffsolError> { + // check that the upper triangular and diagonal parts of a are zero + let s = tableau.s(); + for a in tableau.a() { + for i in 0..s { + for j in i..s { + assert_eq!( + a.get_index(i, j), + Eqn::T::zero(), + "Invalid tableau, expected a(i, j) = 0 for i >= j" + ); + } + } + } + // check that first c is 0 + assert_eq!( + c.get_index(0), + Eqn::T::zero(), + "Invalid tableau, expected c(0) = 0" + ); + Ok(()) + } + + pub(crate) fn check_explicit_rk( problem: &'a OdeSolverProblem, tableau: &Tableau, @@ -221,12 +275,21 @@ where if problem.eqn.mass().is_some() { return Err(DiffsolError::from(OdeSolverError::MassMatrixNotSupported)); } + // check that there isn't any stochastic operator + if problem.eqn.stoch().is_some() { + return Err(DiffsolError::from(OdeSolverError::StochNotSupported)); + } + // check that there is only one a matrix + assert_eq!(tableau.a().len(), 1, "Invalid tableau, expected only one a matrix"); + let a = tableau.a()[0]; + let c = tableau.c(); + // check that the upper triangular and diagonal parts of a are zero let s = tableau.s(); for i in 0..s { for j in i..s { assert_eq!( - tableau.a().get_index(i, j), + a.get_index(i, j), Eqn::T::zero(), "Invalid tableau, expected a(i, j) = 0 for i >= j" ); @@ -236,7 +299,7 @@ where // check last row of a is the same as b for i in 0..s { assert_eq!( - tableau.a().get_index(s - 1, i), + a.get_index(s - 1, i), tableau.b().get_index(i), "Invalid tableau, expected a(s-1, i) = b(i)" ); @@ -244,14 +307,14 @@ where // check that last c is 1 assert_eq!( - tableau.c().get_index(s - 1), + c.get_index(s - 1), Eqn::T::one(), "Invalid tableau, expected c(s-1) = 1" ); // check that first c is 0 assert_eq!( - tableau.c().get_index(0), + c.get_index(0), Eqn::T::zero(), "Invalid tableau, expected c(0) = 0" ); @@ -259,10 +322,13 @@ where } pub(crate) fn skip_first_stage(&self) -> bool { - self.tableau.a().get_index(0, 0) == Eqn::T::zero() + self.tableau.a()[0].get_index(0, 0) == Eqn::T::zero() } pub(crate) fn check_sdirk_rk(tableau: &Tableau) -> Result<(), DiffsolError> { + // check that there is only one a matrix + assert_eq!(tableau.a().len(), 1, "Invalid tableau, expected only one a matrix"); + // check that the upper triangular part of a is zero let s = tableau.s(); for i in 0..s { @@ -403,6 +469,7 @@ where } factor } + pub(crate) fn start_step_attempt( &mut self, @@ -412,10 +479,26 @@ where // if start == 1, then we need to compute the first stage // from the last stage of the previous step if self.skip_first_stage() { - self.diff + self.diff[0] .column_mut(0) - .axpy(h, &self.state.dy, Eqn::T::zero()); - + .axpy(h, &self.state.dy, ); + + // for stochastic methods + if let Some(stoch) = self.problem.eqn.stoch() { + stoch.call_inplace(&self.old_state.y, self.old_state.t, &mut self.stoch_eval); + match stoch.kind() { + StochOpKind::Scalar | StochOpKind::Diagonal => { + self.diff[1].column_mut(0).copy_from_view(&self.stoch_eval.column(0)); + }, + StochOpKind::Additive => { + for i in 0..stoch.nprocess() { + self.diff[1 + i].column_mut(0).copy_from_view(&self.stoch_eval.column(i)); + } + }, + StochOpKind::Other => unreachable!("other not supported here"), + } + } + // sensitivities too if augmented_eqn.is_some() { for (sdiff, ds) in self.sdiff.iter_mut().zip(self.state.ds.iter()) { @@ -450,9 +533,9 @@ where .unwrap_or(true); if integrate_main_eqn { self.old_state.y.copy_from(&self.state.y); - self.diff.columns(0, i).gemv_o( + self.diff[0].columns(0, i).gemv_o( Eqn::T::one(), - &self.a_rows[i], + &self.a_rows[0][i], Eqn::T::one(), &mut self.old_state.y, ); @@ -465,6 +548,74 @@ where self.diff .column_mut(i) .axpy(h, &self.old_state.dy, Eqn::T::zero()); + + if let Some(stoch) = self.problem.eqn.stoch() { + match stoch.kind() { + StochOpKind::Scalar => { + self.diff[1].columns(0, i).gemv_o( + int2_dW[0] / h, + &self.a_rows[1][i], + Eqn::T::one(), + &mut self.old_state.y, + ); + + } + StochOpKind::Diagonal => { + let mut a_rows = self.a_rows[1][i].clone(); + a_rows.component_mul_assign(&int2_dW); + self.diff[1].columns(0, i).gemv_o( + Eqn::T::one() / h, + &a_rows, + Eqn::T::one(), + &mut self.old_state.y, + ); + }, + StochOpKind::Additive => { + for l in 0..stoch.nprocess() { + self.diff[1].columns(0, i).gemv_o( + int2_dW[l] / h, + &self.a_rows[1][i], + Eqn::T::one(), + &mut self.old_state.y, + ); + } + }, + StochOpKind::Other => unreachable!("other not supported here"), + } + + // evaluate stochastic operator + if let Some(stoch_y) = &mut self.stoch_y { + stoch_y.copy_from(&self.old_state.y); + self.diff[0].columns(0, i).gemv_o( + Eqn::T::one(), + &self.a_rows[2][i], + Eqn::T::one(), + &mut stoch_y, + ); + self.diff[1].columns(0, i).gemv_o( + h.sqrt(), + &self.a_rows[3][i], + Eqn::T::one(), + &mut stoch_y, + ); + stoch.call_inplace(&stoch_y, t, &mut self.stoch_eval); + } else { + stoch.call_inplace(&self.old_state.y, t, &mut self.stoch_eval); + } + + // update diff with solved dy + match stoch.kind() { + StochOpKind::Scalar | StochOpKind::Diagonal => { + self.diff[1].column_mut(i).copy_from_view(&self.stoch_eval.column(0)); + }, + StochOpKind::Additive => { + for i in 0..stoch.nprocess() { + self.diff[1 + i].column_mut(i).copy_from_view(&self.stoch_eval.column(i)); + } + }, + StochOpKind::Other => unreachable!("other not supported here"), + } + } // calculate dg and store in gdiff if self.problem.integrate_out { diff --git a/diffsol/src/ode_solver/sde.rs b/diffsol/src/ode_solver/sde.rs deleted file mode 100644 index 0bfe2ecd..00000000 --- a/diffsol/src/ode_solver/sde.rs +++ /dev/null @@ -1,7 +0,0 @@ -use crate::{OdeEquationsStoch, OdeSolverMethod}; - -pub trait SdeSolverMethod<'a, Eqn>: OdeSolverMethod<'a, Eqn> -where - Eqn: OdeEquationsStoch + 'a, -{ -} diff --git a/diffsol/src/ode_solver/tableau.rs b/diffsol/src/ode_solver/tableau.rs index 16489ed8..70ee43cf 100644 --- a/diffsol/src/ode_solver/tableau.rs +++ b/diffsol/src/ode_solver/tableau.rs @@ -20,13 +20,30 @@ use num_traits::{One, Zero}; /// where `be` is the embedded method for error control and `d` is the difference between the main and embedded method. /// /// For continous extension methods, the beta matrix is also included. +/// +/// For SDE method, there can be multiple `a`, `c', and 'd' blocks, the options are: +/// - Single `a`, `c` and `d` block (standard deterministic RK method) +/// - 2 `a`, 1 `c` and 2 `d` blocks (stochastic e.g. Rößler SRA1 method) +/// - 4 `a`, 2 `c` and 4 `d` blocks (stochastic e.g. Rößler SRIW1 method) +/// +/// +/// --------------------- +/// c_0 | a_0 | a_1 | +/// --------------------- +/// c_1 | a_2 | b_3 | +/// ----------------------------- +/// | b | d_0 | d_1 | +/// ------------------------- +/// | d_2 | d_3 | +/// ----------------- +/// /// #[derive(Clone)] pub struct Tableau { - a: M, + a: Vec, b: M::V, - c: M::V, - d: M::V, + c: Vec, + d: Vec, order: usize, beta: Option, } @@ -92,7 +109,7 @@ impl Tableau { let order = 2; - Self::new(a, b, c, d, order, Some(beta)) + Self::new(vec![a], b, vec![c], vec![d], order, Some(beta)) } /// A third order ESDIRK method @@ -153,7 +170,7 @@ impl Tableau { ctx.clone(), ); - Self::new(a, b, c, d, 3, None) + Self::new(vec![a], b, vec![c], vec![d], 3, None) } pub fn tsit45(ctx: M::C) -> Self { @@ -297,15 +314,84 @@ impl Tableau { ); let order = 4; - Self::new(a, b, c, d, order, Some(beta)) + Self::new(vec![a], b, vec![c], vec![d], order, Some(beta)) + } + + /// Rößler SRIW1 method + /// from Rößler, A. (2010). Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations. SIAM Journal on Numerical Analysis, 48(3), 922-952. + pub fn robler_sriw1(ctx: M::C) -> Self { + let a = vec![ + M::from_vec(4, 4, + vec![ + M::T::zero(), M::T::from(3.0/4.0), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + ], ctx.clone()), + + M::from_vec(4, 4, + vec![ + M::T::zero(), M::T::from(3.0/2.0), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + ], ctx.clone()), + M::from_vec(4, 4, + vec![ + M::T::zero(), M::T::from(1.0/4.0), M::T::one(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::from(1.0/4.0), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + ], ctx.clone()), + M::from_vec(4, 4, + vec![ + M::T::zero(), M::T::from(1.0/2.0), M::T::from(-1.0), M::T::from(-5.0), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::from(3.0), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::from(1.0/2.0), + M::T::zero(), M::T::zero(), M::T::zero(), M::T::zero(), + ], ctx.clone()), + ]; + let c = vec![ + M::V::from_vec(vec![M::T::zero(), M::T::from(3.0/4.0), M::T::zero(), M::T::zero()], ctx.clone()), + M::V::from_vec(vec![M::T::zero(), M::T::from(1.0/4.0), M::T::one(), M::T::from(1.0/4.0)], ctx.clone()), + ]; + let b = M::V::from_vec(vec![M::T::from(1.0/3.0), M::T::from(2.0/3.0), M::T::zero(), M::T::zero()], ctx.clone()); + let d = vec![ + M::V::from_vec(vec![M::T::from(-1.0), M::T::from(4.0/3.0), M::T::from(2.0/3.0), M::T::zero()], ctx.clone()), + M::V::from_vec(vec![M::T::from(-1.0), M::T::from(4.0/3.0), M::T::from(-1.0/3.0), M::T::zero()], ctx.clone()), + M::V::from_vec(vec![M::T::from(2.0), M::T::from(-4.0/3.0), M::T::from(-2.0/3.0), M::T::zero()], ctx.clone()), + M::V::from_vec(vec![M::T::from(-2.0), M::T::from(5.0/3.0), M::T::from(-2.0/3.0), M::T::one()], ctx.clone()), + ]; + + let order = 2; + + Self::new(a, b, c, d, order, None) } pub fn new(a: M, b: M::V, c: M::V, d: M::V, order: usize, beta: Option) -> Self { let s = c.len(); - assert_eq!(a.ncols(), s, "Invalid number of rows in a, expected {s}"); - assert_eq!(a.nrows(), s, "Invalid number of columns in a, expected {s}",); + // length of a should be 1, 2 or 4 + assert!(a.len() == 1 || a.len() == 2 || a.len() == 4, "Invalid length of a, expected 1, 2 or 4"); + // length of c should be 1 or 2 + assert!(c.len() == 1 || c.len() == 2, "Invalid length of c, expected 1 or 2"); + // length of d should be 1, 2 or 4 + assert!(d.len() == 1 || d.len() == 2 || d.len() == 4, "Invalid length of d, expected 1, 2 or 4"); + let expected_c_len = if a.len() == 1 { 1 } else { a.len() / 2 }; + assert_eq!(c.len(), expected_c_len, "Invalid length of c, expected {expected_c_len}"); + let expected_d_len = a.len(); + assert_eq!(d.len(), expected_d_len, "Invalid length of d, expected {expected_d_len}"); + for a_i in &a { + assert_eq!(a_i.ncols(), s, "Invalid number of columns in a_i, expected {s}"); + assert_eq!(a_i.nrows(), s, "Invalid number of rows in a_i, expected {s}"); + } assert_eq!(b.len(), s, "Invalid number of elements in b, expected {s}",); - assert_eq!(c.len(), s, "Invalid number of elements in c, expected {s}",); + for d_i in &d { + assert_eq!(d_i.ncols(), s, "Invalid number of columns in d_i, expected {s}"); + assert_eq!(d_i.nrows(), s, "Invalid number of rows in d_i, expected {s}"); + } + for c_i in &c { + assert_eq!(c_i.len(), s, "Invalid number of elements in c_i, expected {s}"); + } if let Some(beta) = &beta { assert_eq!( beta.nrows(), @@ -331,20 +417,20 @@ impl Tableau { self.c.len() } - pub fn a(&self) -> &M { - &self.a + pub fn a(&self) -> &[M] { + self.a.as_slice() } pub fn b(&self) -> &M::V { &self.b } - pub fn c(&self) -> &M::V { - &self.c + pub fn c(&self) -> &[M::V] { + self.c.as_slice() } - pub fn d(&self) -> &M::V { - &self.d + pub fn d(&self) -> &[M::V] { + self.d.as_slice() } pub fn beta(&self) -> Option<&M> { diff --git a/diffsol/src/op/stoch.rs b/diffsol/src/op/stoch.rs index 7ac61b01..1e1bde50 100644 --- a/diffsol/src/op/stoch.rs +++ b/diffsol/src/op/stoch.rs @@ -1,10 +1,8 @@ use super::Op; -use crate::{Scalar, Vector}; +use crate::{DefaultDenseMatrix, Scalar, Vector}; use num_traits::{One, Zero}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StochOpKind { - Zero, +enum StochOpKind { Scalar, Diagonal, Additive, @@ -12,56 +10,20 @@ pub enum StochOpKind { } /// Stochastic differential equation (SDE) operations. +/// +/// In general, this operator computes `F(x, t)`, where `F` is a matrix of size `nstates() x nprocess()`. +/// The matrix `F` is computed by the [Self::call_inplace] method, which returns a dense matrix `y`. +/// The `kind` method returns the type of stochastic operation, either `Scalar`, `Diagonal`, `Additive`, or `Other`, +/// and the `kind` determines how `y` is interpreted. /// -/// For scalar noise, nprocess is 1. -/// For diagonal noise, y_i only depends on x_i and d_w_i. -/// For additive noise, y_i does not depend on x_i. +/// For scalar noise, `y` is a matrix with one column, and the noise is applied as `y * dW`, where `dW` is a scalar Wiener increment. +/// For diagonal noise, `y` is a matrix with one column, which is interpreted as the diagonal of the matrix `F(x, t)`. The noise is applied as `F * dW`, where `dW` is a vector of independent Wiener increments. +/// For additive noise, `y` is a full matrix with `nprocess()` columns that does not depend on `x`, and the noise is applied as `F * dW`, where `dW` is a vector of Wiener increments. +/// Diffsol does not support other types of noise, but the `Other` kind is provided for completeness. pub trait StochOp: Op { + fn kind(&self) -> StochOpKind; fn nprocess(&self) -> usize; - fn process_inplace(&self, x: &Self::V, d_w: &Self::V, t: Self::T, y: &mut [Self::V]); - fn kind(&self) -> StochOpKind { - if self.nprocess() == 0 { - return StochOpKind::Zero; - } - if self.nprocess() == 1 { - return StochOpKind::Scalar; - } - let mut y = vec![Self::V::zeros(self.nout(), self.context().clone()); self.nprocess()]; - let mut x = Self::V::zeros(self.nstates(), self.context().clone()); - x.fill(Self::T::NAN); - let mut d_w = Self::V::zeros(self.nprocess(), self.context().clone()); - d_w.fill(Self::T::one()); - let t = Self::T::zero(); - self.process_inplace(&x, &d_w, t, &mut y); - // if none of the outputs has nans, it is additive - if y.iter() - .all(|y_j| !y_j.clone_as_vec().iter().any(|&val| val.is_nan())) - { - return StochOpKind::Additive; - } - - x.fill(Self::T::one()); - - for i in 0..self.nprocess() { - if i != 0 { - d_w.set_index(i - 1, Self::T::one()); - } - d_w.set_index(i, Self::T::NAN); - self.process_inplace(&x, &d_w, t, &mut y); - - // if any of the y[j] j != i has nans, it is other - for (j, y_j) in y.iter().enumerate() { - if j != i { - let has_nans = y_j.clone_as_vec().iter().any(|&val| val.is_nan()); - if has_nans { - return StochOpKind::Other; - } - } - } - } - // must be diagonal - StochOpKind::Diagonal - } + fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut ::M) where Self::V: DefaultDenseMatrix; } #[cfg(test)] From a4ffd5c731e9d807034a3d0271433778dec32e9f Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 25 Sep 2025 17:04:34 +0000 Subject: [PATCH 3/4] stoch now can be a nonlin or lin op --- diffsol/src/lib.rs | 2 +- diffsol/src/ode_equations/mod.rs | 15 ++- diffsol/src/ode_solver/runge_kutta.rs | 174 +++++++++++++------------- diffsol/src/op/linear_op.rs | 2 +- diffsol/src/op/stoch.rs | 4 +- 5 files changed, 105 insertions(+), 92 deletions(-) diff --git a/diffsol/src/lib.rs b/diffsol/src/lib.rs index 67a43cff..1ca5bf31 100644 --- a/diffsol/src/lib.rs +++ b/diffsol/src/lib.rs @@ -190,7 +190,7 @@ pub use ode_equations::{ sens_equations::SensInit, sens_equations::SensRhs, AugmentedOdeEquations, AugmentedOdeEquationsImplicit, NoAug, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens, OdeEquationsRef, OdeEquationsSens, - OdeSolverEquations, + OdeSolverEquations, StochEnum, }; use ode_solver::jacobian_update::JacobianUpdate; pub use ode_solver::sde::SdeSolverMethod; diff --git a/diffsol/src/ode_equations/mod.rs b/diffsol/src/ode_equations/mod.rs index dd401639..32f8957e 100644 --- a/diffsol/src/ode_equations/mod.rs +++ b/diffsol/src/ode_equations/mod.rs @@ -197,7 +197,8 @@ pub trait OdeEquationsRef<'a, ImplicitBounds: Sealed = Bounds<&'a Self>>: Op { type Root: NonLinearOp; type Init: ConstantOp; type Out: NonLinearOp; - type Stoch: StochOp; + type Stoch: NonLinearOp; + type StochAdditive: LinearOp; } impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T { @@ -207,6 +208,14 @@ impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T { type Init = >::Init; type Out = >::Out; type Stoch = >::Stoch; + type StochAdditive = >::StochAdditive; +} + +pub enum StochEnum { + Scalar(A), + Diagonal(A), + Additive(B), + None, } // seal the trait so that users must use the provided default type for ImplicitBounds @@ -251,8 +260,8 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> { None } - fn stoch(&self) -> Option<>::Stoch> { - None + fn stoch(&self) -> StochEnum<>::Stoch, >::StochAdditive> { + StochEnum::None } /// returns the initial condition, i.e. `y(t)`, where `t` is the initial time diff --git a/diffsol/src/ode_solver/runge_kutta.rs b/diffsol/src/ode_solver/runge_kutta.rs index ef6f6d7e..1b360fd1 100644 --- a/diffsol/src/ode_solver/runge_kutta.rs +++ b/diffsol/src/ode_solver/runge_kutta.rs @@ -9,6 +9,7 @@ use crate::OdeEquationsImplicit; use crate::OdeSolverStopReason; use crate::RkState; use crate::RootFinder; +use crate::StochEnum; use crate::StochOpKind; use crate::Tableau; use crate::{ @@ -51,7 +52,7 @@ where gdiff: M, old_state: RkState, is_state_mutated: bool, - stoch_eval: ::M, + stoch_dy: Option, stoch_y: Option, error: Option, @@ -85,7 +86,7 @@ where out_error: self.out_error.clone(), sens_error: self.sens_error.clone(), sens_out_error: self.sens_out_error.clone(), - stoch_eval: self.stoch_eval.clone(), + stoch_dy: self.stoch_dy.clone(), } } } @@ -122,11 +123,6 @@ where let nstates = state.y.len(); let order = tableau.s(); let ctx = problem.context(); - let (nprocess, kind) = if let Some(stoch) = problem.eqn.stoch() { - (stoch.nprocess(), stoch.kind()) - } else { - (0, StochOpKind::Other) - }; let s = tableau.s(); let mut a_rows = Vec::with_capacity(tableau.a().len()); @@ -150,17 +146,26 @@ where } else { None }; - let n_stoch_diff = match kind { - StochOpKind::Scalar | StochOpKind::Diagonal => 1, - StochOpKind::Additive => nprocess, - StochOpKind::Other => 0, + let n_stoch_diff = match problem.eqn.stoch() { + StochEnum::Scalar(_) => 1, + StochEnum::Diagonal(_) => 2, + StochEnum::Additive(_) => 0, + StochEnum::None => 0, }; - let stoch_y = match kind { - StochOpKind::Scalar | StochOpKind::Diagonal => Some(Eqn::V::zeros(nstates, ctx.clone())), + let stoch_y = match problem.eqn.stoch() { + StochEnum::Scalar(_) | StochEnum::Diagonal(_) => + Some(Eqn::V::zeros(nstates, ctx.clone())), _ => None, + }; + let stoch_dy = match problem.eqn.stoch() { + StochEnum::Scalar(_) | StochEnum::Diagonal(_) => + Some(Eqn::V::zeros(nstates, ctx.clone())), + _ => None, + }; + let mut diff = vec![M::zeros(nstates, order, ctx.clone()); 1 + n_stoch_diff]; + if let StochEnum::Additive(op) = problem.eqn.stoch() { + diff.push(M::zeros(op.nstates(), order, ctx.clone())); } - let stoch_eval = StochOpKind::Scalar => ::M::zeros(nstates, n_stoch_diff, ctx.clone()); - let diff = vec![M::zeros(nstates, order, ctx.clone()); 1 + n_stoch_diff]; let gdiff_rows = if problem.integrate_out { problem.eqn.out().unwrap().nout() } else { @@ -191,7 +196,7 @@ where root_finder, tstop: None, is_state_mutated: false, - stoch_eval, + stoch_dy, stoch_y, diff, gdiff, @@ -484,21 +489,30 @@ where .axpy(h, &self.state.dy, ); // for stochastic methods - if let Some(stoch) = self.problem.eqn.stoch() { - stoch.call_inplace(&self.old_state.y, self.old_state.t, &mut self.stoch_eval); - match stoch.kind() { - StochOpKind::Scalar | StochOpKind::Diagonal => { - self.diff[1].column_mut(0).copy_from_view(&self.stoch_eval.column(0)); - }, - StochOpKind::Additive => { - for i in 0..stoch.nprocess() { - self.diff[1 + i].column_mut(0).copy_from_view(&self.stoch_eval.column(i)); - } - }, - StochOpKind::Other => unreachable!("other not supported here"), - } - } - + match self.problem.eqn.stoch() { + StochEnum::Scalar(op) => { + let stoch_dy = self.stoch_dy.as_mut().unwrap(); + op.call_inplace(&self.old_state.y, self.old_state.t, stoch_dy); + self.diff[1].column_mut(0).copy_from(stoch_dy); + }, + StochEnum::Diagonal(op) => { + let stoch_dy = self.stoch_dy.as_mut().unwrap(); + op.call_inplace(&self.old_state.y, self.old_state.t, stoch_dy); + self.diff[1].column_mut(0).copy_from(stoch_dy); + stoch_dy.component_mul_assign(int2_dW); + self.diff[2].column_mut(0).copy_from(stoch_dy); + }, + StochEnum::Additive(op) => { + for a in self.a_rows[1][0].iter() { + + } + let stoch_dy = self.stoch_dy.as_mut().unwrap(); + op.call_inplace(self.old_state.t, stoch_dy); + self.diff[1].column_mut(0).copy_from(stoch_dy); + }, + _ => (), + }; + // sensitivities too if augmented_eqn.is_some() { for (sdiff, ds) in self.sdiff.iter_mut().zip(self.state.ds.iter()) { @@ -548,43 +562,14 @@ where self.diff .column_mut(i) .axpy(h, &self.old_state.dy, Eqn::T::zero()); - - if let Some(stoch) = self.problem.eqn.stoch() { - match stoch.kind() { - StochOpKind::Scalar => { - self.diff[1].columns(0, i).gemv_o( - int2_dW[0] / h, - &self.a_rows[1][i], - Eqn::T::one(), - &mut self.old_state.y, - ); - - } - StochOpKind::Diagonal => { - let mut a_rows = self.a_rows[1][i].clone(); - a_rows.component_mul_assign(&int2_dW); - self.diff[1].columns(0, i).gemv_o( - Eqn::T::one() / h, - &a_rows, - Eqn::T::one(), - &mut self.old_state.y, - ); - }, - StochOpKind::Additive => { - for l in 0..stoch.nprocess() { - self.diff[1].columns(0, i).gemv_o( - int2_dW[l] / h, - &self.a_rows[1][i], - Eqn::T::one(), - &mut self.old_state.y, - ); - } - }, - StochOpKind::Other => unreachable!("other not supported here"), - } - - // evaluate stochastic operator - if let Some(stoch_y) = &mut self.stoch_y { + match self.problem.eqn.stoch() { + StochEnum::Scalar(op) => { + self.diff[1].columns(0, i).gemv_o( + int2_dW[0] / h, + &self.a_rows[1][i], + Eqn::T::one(), + &mut self.old_state.y, + ); stoch_y.copy_from(&self.old_state.y); self.diff[0].columns(0, i).gemv_o( Eqn::T::one(), @@ -598,24 +583,43 @@ where Eqn::T::one(), &mut stoch_y, ); - stoch.call_inplace(&stoch_y, t, &mut self.stoch_eval); - } else { - stoch.call_inplace(&self.old_state.y, t, &mut self.stoch_eval); - } - - // update diff with solved dy - match stoch.kind() { - StochOpKind::Scalar | StochOpKind::Diagonal => { - self.diff[1].column_mut(i).copy_from_view(&self.stoch_eval.column(0)); - }, - StochOpKind::Additive => { - for i in 0..stoch.nprocess() { - self.diff[1 + i].column_mut(i).copy_from_view(&self.stoch_eval.column(i)); - } - }, - StochOpKind::Other => unreachable!("other not supported here"), + op.call_inplace(&stoch_y, t, &mut self.stoch_dy); + self.diff[1].column_mut(i).copy_from(&self.stoch_dy); + } - } + StochEnum::Diagonal(op) => { + self.diff[2].columns(0, i).gemv_o( + 1 / h, + &self.a_rows[1][i], + Eqn::T::one(), + &mut self.old_state.y, + ); + stoch_y.copy_from(&self.old_state.y); + self.diff[0].columns(0, i).gemv_o( + Eqn::T::one(), + &self.a_rows[2][i], + Eqn::T::one(), + &mut stoch_y, + ); + self.diff[1].columns(0, i).gemv_o( + h.sqrt(), + &self.a_rows[3][i], + Eqn::T::one(), + &mut stoch_y, + ); + op.call_inplace(&stoch_y, t, &mut self.stoch_dy); + self.diff[1].column_mut(i).copy_from(&self.stoch_dy); + self.stoch_dy.component_mul_inplace(&int2_dW); + self.diff[2].column_mut(i).copy_from(&self.stoch_dy); + }, + StochEnum::Additive(lin_op) => { + for s in 0..i { + lin_op.gemv_inplace(int2_dW * self.a_rows[1][s] / h, t + self.tableau.c[s], Eqn::T::one(), &mut self.old_state.y); + } + }, + StochEnum::None => (), + }; + // calculate dg and store in gdiff if self.problem.integrate_out { diff --git a/diffsol/src/op/linear_op.rs b/diffsol/src/op/linear_op.rs index 257f78e8..1075ef92 100644 --- a/diffsol/src/op/linear_op.rs +++ b/diffsol/src/op/linear_op.rs @@ -5,7 +5,7 @@ use num_traits::{One, Zero}; /// LinearOp is a trait for linear operators (i.e. they only depend linearly on the input `x`), see [crate::NonLinearOp] for a non-linear op. /// /// An example of a linear operator is a matrix-vector product `y = A(t) * x`, where `A(t)` is a matrix. -/// It extends the [Op] trait with methods for calling the operator via a GEMV-like operation (i.e. `y = t * A * x + beta * y`), and for computing the matrix representation of the operator. +/// It extends the [Op] trait with methods for calling the operator via a GEMV-like operation (i.e. `y = A(t) * x + beta * y`), and for computing the matrix representation of the operator. pub trait LinearOp: Op { /// Compute the operator `y = A(t) * x` at a given state and time, the default implementation uses [Self::gemv_inplace]. fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { diff --git a/diffsol/src/op/stoch.rs b/diffsol/src/op/stoch.rs index 1e1bde50..826c0254 100644 --- a/diffsol/src/op/stoch.rs +++ b/diffsol/src/op/stoch.rs @@ -17,13 +17,13 @@ enum StochOpKind { /// and the `kind` determines how `y` is interpreted. /// /// For scalar noise, `y` is a matrix with one column, and the noise is applied as `y * dW`, where `dW` is a scalar Wiener increment. -/// For diagonal noise, `y` is a matrix with one column, which is interpreted as the diagonal of the matrix `F(x, t)`. The noise is applied as `F * dW`, where `dW` is a vector of independent Wiener increments. +/// For diagonal noise, `y` is a diagonal matrix, which is interpreted as the diagonal of the matrix `F(x, t)`. The noise is applied as `F * dW`, where `dW` is a vector of independent Wiener increments. /// For additive noise, `y` is a full matrix with `nprocess()` columns that does not depend on `x`, and the noise is applied as `F * dW`, where `dW` is a vector of Wiener increments. /// Diffsol does not support other types of noise, but the `Other` kind is provided for completeness. pub trait StochOp: Op { fn kind(&self) -> StochOpKind; fn nprocess(&self) -> usize; - fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut ::M) where Self::V: DefaultDenseMatrix; + fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M); } #[cfg(test)] From bc2fbb0560ad280026872b76b57335fcce57eeb2 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sat, 27 Sep 2025 18:52:03 +0000 Subject: [PATCH 4/4] add scaled gemv --- diffsol/src/matrix/mod.rs | 18 ++++++++++++++++++ diffsol/src/ode_solver/runge_kutta.rs | 15 +++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/diffsol/src/matrix/mod.rs b/diffsol/src/matrix/mod.rs index b1e632a8..36b549f0 100644 --- a/diffsol/src/matrix/mod.rs +++ b/diffsol/src/matrix/mod.rs @@ -118,6 +118,8 @@ pub trait MatrixView<'a>: type Owned; fn into_owned(self) -> Self::Owned; + + /// Perform a matrix-vector multiplication `y = self * x + beta * y`. fn gemv_v( @@ -129,6 +131,19 @@ pub trait MatrixView<'a>: ); fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V); + + /// Perform a matrix-vector multiplication that is scaled by a vector instead of a scalar `y += alpha .* self * x`. + fn scaled_gemv_o( + &self, + alpha: &Self::V, + x: &Self::V, + y: &mut Self::V, + ) { + let mut temp = Self::V::zeros(y.len(), self.context().clone()); + self.gemv(Self::T::one(), x, Self::T::zero(), &mut temp); + temp.mul_assign(alpha); + y.add_assign(&temp); + } } /// A base matrix trait (including sparse and dense matrices) @@ -153,6 +168,9 @@ pub trait Matrix: MatrixCommon + Mul, Output = Self> + Clone + 's /// Perform a matrix-vector multiplication `y = alpha * self * x + beta * y`. fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V); + + + /// Copy the contents of `other` into `self` fn copy_from(&mut self, other: &Self); diff --git a/diffsol/src/ode_solver/runge_kutta.rs b/diffsol/src/ode_solver/runge_kutta.rs index 1b360fd1..85e9a8c4 100644 --- a/diffsol/src/ode_solver/runge_kutta.rs +++ b/diffsol/src/ode_solver/runge_kutta.rs @@ -131,6 +131,7 @@ where for i in 0..s { let mut row = Vec::with_capacity(i); for j in 0..i { + //TODO: probably could just not push if the value is 0?????? row.push(tableau.a().get_index(i, j)); } a_rows_i.push(Eqn::V::from_vec(row, ctx.clone())); @@ -565,7 +566,7 @@ where match self.problem.eqn.stoch() { StochEnum::Scalar(op) => { self.diff[1].columns(0, i).gemv_o( - int2_dW[0] / h, + int2_dW_div_h[0], &self.a_rows[1][i], Eqn::T::one(), &mut self.old_state.y, @@ -588,8 +589,8 @@ where } StochEnum::Diagonal(op) => { - self.diff[2].columns(0, i).gemv_o( - 1 / h, + self.diff[1].columns(0, i).scaled_gemv_o( + int2_dW_div_h, &self.a_rows[1][i], Eqn::T::one(), &mut self.old_state.y, @@ -609,12 +610,14 @@ where ); op.call_inplace(&stoch_y, t, &mut self.stoch_dy); self.diff[1].column_mut(i).copy_from(&self.stoch_dy); - self.stoch_dy.component_mul_inplace(&int2_dW); - self.diff[2].column_mut(i).copy_from(&self.stoch_dy); }, StochEnum::Additive(lin_op) => { for s in 0..i { - lin_op.gemv_inplace(int2_dW * self.a_rows[1][s] / h, t + self.tableau.c[s], Eqn::T::one(), &mut self.old_state.y); + if self.a_rows[1][i][s] != Eqn::T::zero() { + // TODO: move tmp into state to avoid reallocation + let tmp = int2_dw_div_h * self.a_rows[1][i][s]; + lin_op.gemv_inplace(tmp, t + self.tableau.c[s], Eqn::T::one(), &mut self.old_state.y); + } } }, StochEnum::None => (),