From bb219bc0fd397a65ac73afd7039b386be51fb90b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 8 May 2025 11:47:31 +0200 Subject: [PATCH 1/3] fix(stan): rng for generated quantities Generated quantities have their own random number generator. The seed for this generator did not depend on the global seed of the model, so that random stream for two different sampler runs were reusing the same randomness. --- src/stan.rs | 7 ++++--- tests/test_stan.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/stan.rs b/src/stan.rs index e258096..6cf92c3 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -531,8 +531,8 @@ impl Model for StanModel { fn new_trace<'a, S: Settings, R: rand::Rng + ?Sized>( &'a self, - _rng: &mut R, - chain: u64, + rng: &mut R, + _chain: u64, settings: &S, ) -> anyhow::Result> { let draws = settings.hint_num_tune() + settings.hint_num_draws(); @@ -541,7 +541,8 @@ impl Model for StanModel { .iter() .map(|var| Vec::with_capacity(var.size * draws)) .collect(); - let rng = self.model.new_rng(chain as u32)?; + let seed = rng.next_u32(); + let rng = self.model.new_rng(seed)?; let buffer = vec![0f64; self.model.param_num(true, true)]; Ok(StanTrace { model: self, diff --git a/tests/test_stan.py b/tests/test_stan.py index 89201c5..f7ba498 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -27,6 +27,43 @@ def test_stan_model(): trace.posterior.a # noqa: B018 +@pytest.mark.stan +def test_seed(): + model = """ + data {} + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + generated quantities { + real b = normal_rng(0, 1); + } + """ + + compiled_model = nutpie.compile_stan_model(code=model) + trace = nutpie.sample(compiled_model, seed=42) + trace2 = nutpie.sample(compiled_model, seed=42) + trace3 = nutpie.sample(compiled_model, seed=43) + + assert np.allclose(trace.posterior.a, trace2.posterior.a) + assert np.allclose(trace.posterior.b, trace2.posterior.b) + + assert not np.allclose(trace.posterior.a, trace3.posterior.a) + assert not np.allclose(trace.posterior.b, trace3.posterior.b) + # Check that all chains are pairwise different + for i in range(len(trace.posterior.a)): + for j in range(i + 1, len(trace.posterior.a)): + assert not np.allclose(trace.posterior.a[i], trace.posterior.a[j]) + assert not np.allclose(trace.posterior.b[i], trace.posterior.b[j]) + # Check that all chains are pairwise different between seeds + for i in range(len(trace.posterior.a)): + for j in range(len(trace3.posterior.a)): + assert not np.allclose(trace.posterior.a[i], trace3.posterior.a[j]) + assert not np.allclose(trace.posterior.b[i], trace3.posterior.b[j]) + + @pytest.mark.stan def test_stan_model_data(): model = """ From d830c10a3882bf993df83023602c57eefbe7deff Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 6 May 2025 21:30:17 +0200 Subject: [PATCH 2/3] fix: correctly handle tuples in stan traces --- src/stan.rs | 562 ++++++++++++++++++++++++++++++++++++++++----- tests/test_stan.py | 84 +++++++ 2 files changed, 589 insertions(+), 57 deletions(-) diff --git a/src/stan.rs b/src/stan.rs index 6cf92c3..f1e6d54 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::{ffi::CString, path::PathBuf}; -use anyhow::Context; +use anyhow::{bail, Context}; use arrow::array::{Array, FixedSizeListArray, Float64Array, StructArray}; use arrow::datatypes::{DataType, Field}; use bridgestan::open_library; @@ -26,7 +26,7 @@ type InnerModel = bridgestan::Model>; #[derive(Clone)] pub struct StanLibrary(Arc); -#[derive(Clone)] +#[derive(Clone, Debug)] struct Parameter { name: String, shape: Vec, @@ -40,7 +40,7 @@ impl StanLibrary { #[new] fn new(path: PathBuf) -> PyResult { let lib = open_library(path) - .map_err(|e| PyValueError::new_err(format!("Could not open stan libray: {}", e)))?; + .map_err(|e| PyValueError::new_err(format!("Could not open stan libray: {e}")))?; Ok(Self(Arc::new(lib))) } } @@ -64,6 +64,16 @@ impl StanVariable { fn size(&self) -> usize { self.0.size } + + #[getter] + fn start_idx(&self) -> usize { + self.0.start_idx + } + + #[getter] + fn end_idx(&self) -> usize { + self.0.end_idx + } } #[pyclass] @@ -75,68 +85,155 @@ pub struct StanModel { } /// Return meta information about the constrained parameters of the model -fn params( - model: &InnerModel, - include_tp: bool, - include_gq: bool, -) -> anyhow::Result> { - let var_string = model.param_names(include_tp, include_gq); - let name_idxs: anyhow::Result)>> = var_string +fn params(var_string: &str) -> anyhow::Result> { + // Parse each variable string into (name, is_complex, indices) + let parsed_variables: anyhow::Result)>> = var_string .split(',') .map(|var| { - let mut parts = var.split('.'); - let name = parts - .next() - .ok_or_else(|| anyhow::Error::msg("Invalid parameter name"))?; - let idxs: anyhow::Result> = parts - .map(|mut idx| { - if idx == "real" { - idx = "1"; - } - if idx == "imag" { - idx = "2"; - } - let idx: usize = idx - .parse() - .map_err(|_| anyhow::Error::msg("Invalid parameter name"))?; - Ok(idx - 1) - }) - .collect(); - Ok((name, idxs?)) + let mut indices = vec![]; + let mut remaining = var; + let mut complex_suffix = None; + + // Parse from right to left, extracting indices and checking for complex type + while let Some(idx) = remaining.rfind('.') { + let suffix = &remaining[(idx + 1)..]; + + // Handle complex number suffixes + if suffix == "real" || suffix == "imag" { + complex_suffix = Some(suffix); + remaining = &remaining[..idx]; + continue; + } + + // Try to parse as index + if let Ok(index) = suffix.parse::() { + // Convert from 1-based to 0-based indexing + let zero_based_idx = index.checked_sub(1).ok_or_else(|| { + anyhow::Error::msg("Invalid parameter index (must be > 0)") + })?; + + indices.push(zero_based_idx); + remaining = &remaining[..idx]; + } else { + // Not a number - this is part of the variable name + break; + } + } + + // Variable name is what remains + let name = remaining.trim().to_string(); + + // Reverse indices since we parsed right-to-left + indices.reverse(); + + Ok((name, complex_suffix.is_some(), indices)) }) .collect(); + // Group variables by name and build Parameter objects let mut variables = Vec::new(); let mut start_idx = 0; - for (name, idxs) in &name_idxs?.iter().chunk_by(|(name, _)| name) { - let mut shape: Vec = idxs - .map(|(_name, idx)| idx) - .fold(None, |acc, elem| { - let mut shape = acc.unwrap_or(elem.clone()); - shape - .iter_mut() - .zip_eq(elem.iter()) - .for_each(|(old, &new)| { - *old = new.max(*old); - }); - Some(shape) - }) - .unwrap_or(vec![]); - shape.iter_mut().for_each(|max_idx| *max_idx += 1); + + for (name, group) in &parsed_variables?.iter().chunk_by(|(name, _, _)| name) { + // Find maximum shape and check if this is a complex variable + let (shape, is_complex) = determine_variable_shape(group) + .context(format!("Error while parsing stan variable {}", name))?; + + // Calculate total size of this variable let size = shape.iter().product(); - let end_idx = start_idx + size; - variables.push(Parameter { - name: name.to_string(), - shape, - size, - start_idx, - end_idx, - }); + let mut end_idx = start_idx + size; + + // Create Parameter objects (one for real and one for imag if complex) + if is_complex { + variables.push(Parameter { + name: format!("{}.real", name), + shape: shape.clone(), + size, + start_idx, + end_idx, + }); + start_idx = end_idx; + end_idx = start_idx + size; + variables.push(Parameter { + name: format!("{}.imag", name), + shape, + size, + start_idx, + end_idx, + }); + } else { + variables.push(Parameter { + name: name.to_string(), + shape, + size, + start_idx, + end_idx, + }); + } + + // Move to the next variable start_idx = end_idx; } + Ok(variables) } +// Helper function to determine the shape and complex flag for a group of variables +fn determine_variable_shape<'a, I>(group: I) -> anyhow::Result<(Vec, bool)> +where + I: Iterator)>, +{ + let group = group.collect_vec(); + + let (mut shape, is_complex) = group + .iter() + .map(|&(_, is_complex, ref idx)| (idx, is_complex)) + .fold(None, |acc, (elem_index, &elem_is_complex)| { + let (mut shape, is_complex) = acc.unwrap_or((elem_index.clone(), elem_is_complex)); + assert!( + is_complex == elem_is_complex, + "Inconsistent complex flags for same variable" + ); + + // Find maximum index in each dimension + shape + .iter_mut() + .zip_eq(elem_index.iter()) + .for_each(|(old, &new)| { + *old = new.max(*old); + }); + + Some((shape, is_complex)) + }) + .expect("List of variable entries cannot be empty"); + + shape.iter_mut().for_each(|max_idx| *max_idx += 1); + + // Check if the indices are in Fortran order + let mut expected_index: Vec = vec![0; shape.len()]; + let mut expect_imag = false; + for (_, _, idx) in group.iter() { + if idx != &expected_index { + bail!("Stan returned data that was not in the expected order.") + } + if is_complex { + expect_imag = !expect_imag; + } + if !expect_imag { + // increment expected index + for i in 0..shape.len() { + if expected_index[i] < shape[i] - 1 { + expected_index[i] += 1; + break; + } else { + expected_index[i] = 0; + } + } + } + } + + Ok((shape, is_complex)) +} #[pymethods] impl StanModel { #[new] @@ -155,7 +252,9 @@ impl StanModel { let model = Arc::new( bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, ); - let variables = params(&model, true, true)?; + + let var_string = model.param_names(true, true); + let variables = params(var_string)?; let transform_adapter = transform_adapter.map(PyTransformAdapt::new); Ok(StanModel { model, @@ -556,7 +655,7 @@ impl Model for StanModel { fn math(&self) -> anyhow::Result> { Ok(CpuMath::new(StanDensity { inner: &self.model, - transform_adapter: self.transform_adapter.as_ref().map(|v| v.clone()), + transform_adapter: self.transform_adapter.clone(), })) } @@ -606,7 +705,6 @@ mod tests { 0., 6., 12., 18., 24., 2., 8., 14., 20., 26., 4., 10., 16., 22., 28., 1., 7., 13., 19., 25., 3., 9., 15., 21., 27., 5., 11., 17., 23., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); let data = vec![ @@ -619,7 +717,6 @@ mod tests { 0., 6., 12., 18., 24., 2., 8., 14., 20., 26., 4., 10., 16., 22., 28., 1., 7., 13., 19., 25., 3., 9., 15., 21., 27., 5., 11., 17., 23., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); let data = vec![ @@ -632,7 +729,358 @@ mod tests { 0., 15., 5., 20., 10., 25., 1., 16., 6., 21., 11., 26., 2., 17., 7., 22., 12., 27., 3., 18., 8., 23., 13., 28., 4., 19., 9., 24., 14., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); } + + #[test] + fn parse_vars() { + let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2"; + let parsed = super::params(vars).unwrap(); + assert!(parsed.len() == 1); + let parsed = parsed[0].clone(); + assert!(parsed.name == "x"); + assert!(parsed.shape == vec![3, 2]); + + // Incorrect order + let vars = "x.1.2,x.1.1,x.2.1,x.2.2,x.3.1,x.3.2"; + assert!(super::params(vars).is_err()); + + // Incorrect order + let vars = "x.1.2.real,x.1.2.imag"; + assert!(super::params(vars).is_err()); + + let vars = "x.1.1.real,x.1.1.imag,x.2.1.real,x.2.1.imag,x.3.1.real,x.3.1.imag"; + let parsed = super::params(vars).unwrap(); + assert!(parsed.len() == 2); + let var = parsed[0].clone(); + assert!(var.name == "x.real"); + assert!(var.shape == vec![3, 1]); + + let var = parsed[1].clone(); + assert!(var.name == "x.imag"); + assert!(var.shape == vec![3, 1]); + + // Test single variable + let vars = "alpha"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "alpha"); + assert_eq!(var.shape, Vec::::new()); + assert_eq!(var.size, 1); + + // Test multiple scalar variables + let vars = "alpha,beta,gamma"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 3); + assert_eq!(parsed[0].name, "alpha"); + assert_eq!(parsed[1].name, "beta"); + assert_eq!(parsed[2].name, "gamma"); + + // Test 1D array + let vars = "theta.1,theta.2,theta.3,theta.4"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "theta"); + assert_eq!(var.shape, vec![4]); + assert_eq!(var.size, 4); + + // Test variable name with colons and dots + let vars = "x:1:2.4:1.1,x:1:2.4:1.2,x:1:2.4:1.3"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "x:1:2.4:1"); + assert_eq!(var.shape, vec![3]); + assert_eq!(var.size, 3); + + let vars = " + a, + base, + base_i, + pair:1, + pair:2, + nested:1, + nested:2:1, + nested:2:2.real, + nested:2:2.imag, + arr_pair.1:1, + arr_pair.1:2, + arr_pair.2:1, + arr_pair.2:2, + arr_very_nested.1:1:1, + arr_very_nested.1:1:2:1, + arr_very_nested.1:1:2:2.real, + arr_very_nested.1:1:2:2.imag, + arr_very_nested.1:2, + arr_very_nested.2:1:1, + arr_very_nested.2:1:2:1, + arr_very_nested.2:1:2:2.real, + arr_very_nested.2:1:2:2.imag, + arr_very_nested.2:2, + arr_very_nested.3:1:1, + arr_very_nested.3:1:2:1, + arr_very_nested.3:1:2:2.real, + arr_very_nested.3:1:2:2.imag, + arr_very_nested.3:2, + arr_2d_pair.1.1:1, + arr_2d_pair.1.1:2, + arr_2d_pair.2.1:1, + arr_2d_pair.2.1:2, + arr_2d_pair.3.1:1, + arr_2d_pair.3.1:2, + arr_2d_pair.1.2:1, + arr_2d_pair.1.2:2, + arr_2d_pair.2.2:1, + arr_2d_pair.2.2:2, + arr_2d_pair.3.2:1, + arr_2d_pair.3.2:2, + basep1, + basep2, + basep3, + basep4, + basep5, + ultimate.1.1:1.1:1, + ultimate.1.1:1.1:2.1, + ultimate.1.1:1.1:2.2, + ultimate.1.1:1.2:1, + ultimate.1.1:1.2:2.1, + ultimate.1.1:1.2:2.2, + ultimate.1.1:2.1.1, + ultimate.1.1:2.2.1, + ultimate.1.1:2.3.1, + ultimate.1.1:2.4.1, + ultimate.1.1:2.1.2, + ultimate.1.1:2.2.2, + ultimate.1.1:2.3.2, + ultimate.1.1:2.4.2, + ultimate.1.1:2.1.3, + ultimate.1.1:2.2.3, + ultimate.1.1:2.3.3, + ultimate.1.1:2.4.3, + ultimate.1.1:2.1.4, + ultimate.1.1:2.2.4, + ultimate.1.1:2.3.4, + ultimate.1.1:2.4.4, + ultimate.1.1:2.1.5, + ultimate.1.1:2.2.5, + ultimate.1.1:2.3.5, + ultimate.1.1:2.4.5, + ultimate.2.1:1.1:1, + ultimate.2.1:1.1:2.1, + ultimate.2.1:1.1:2.2, + ultimate.2.1:1.2:1, + ultimate.2.1:1.2:2.1, + ultimate.2.1:1.2:2.2, + ultimate.2.1:2.1.1, + ultimate.2.1:2.2.1, + ultimate.2.1:2.3.1, + ultimate.2.1:2.4.1, + ultimate.2.1:2.1.2, + ultimate.2.1:2.2.2, + ultimate.2.1:2.3.2, + ultimate.2.1:2.4.2, + ultimate.2.1:2.1.3, + ultimate.2.1:2.2.3, + ultimate.2.1:2.3.3, + ultimate.2.1:2.4.3, + ultimate.2.1:2.1.4, + ultimate.2.1:2.2.4, + ultimate.2.1:2.3.4, + ultimate.2.1:2.4.4, + ultimate.2.1:2.1.5, + ultimate.2.1:2.2.5, + ultimate.2.1:2.3.5, + ultimate.2.1:2.4.5, + ultimate.1.2:1.1:1, + ultimate.1.2:1.1:2.1, + ultimate.1.2:1.1:2.2, + ultimate.1.2:1.2:1, + ultimate.1.2:1.2:2.1, + ultimate.1.2:1.2:2.2, + ultimate.1.2:2.1.1, + ultimate.1.2:2.2.1, + ultimate.1.2:2.3.1, + ultimate.1.2:2.4.1, + ultimate.1.2:2.1.2, + ultimate.1.2:2.2.2, + ultimate.1.2:2.3.2, + ultimate.1.2:2.4.2, + ultimate.1.2:2.1.3, + ultimate.1.2:2.2.3, + ultimate.1.2:2.3.3, + ultimate.1.2:2.4.3, + ultimate.1.2:2.1.4, + ultimate.1.2:2.2.4, + ultimate.1.2:2.3.4, + ultimate.1.2:2.4.4, + ultimate.1.2:2.1.5, + ultimate.1.2:2.2.5, + ultimate.1.2:2.3.5, + ultimate.1.2:2.4.5, + ultimate.2.2:1.1:1, + ultimate.2.2:1.1:2.1, + ultimate.2.2:1.1:2.2, + ultimate.2.2:1.2:1, + ultimate.2.2:1.2:2.1, + ultimate.2.2:1.2:2.2, + ultimate.2.2:2.1.1, + ultimate.2.2:2.2.1, + ultimate.2.2:2.3.1, + ultimate.2.2:2.4.1, + ultimate.2.2:2.1.2, + ultimate.2.2:2.2.2, + ultimate.2.2:2.3.2, + ultimate.2.2:2.4.2, + ultimate.2.2:2.1.3, + ultimate.2.2:2.2.3, + ultimate.2.2:2.3.3, + ultimate.2.2:2.4.3, + ultimate.2.2:2.1.4, + ultimate.2.2:2.2.4, + ultimate.2.2:2.3.4, + ultimate.2.2:2.4.4, + ultimate.2.2:2.1.5, + ultimate.2.2:2.2.5, + ultimate.2.2:2.3.5, + ultimate.2.2:2.4.5, + ultimate.1.3:1.1:1, + ultimate.1.3:1.1:2.1, + ultimate.1.3:1.1:2.2, + ultimate.1.3:1.2:1, + ultimate.1.3:1.2:2.1, + ultimate.1.3:1.2:2.2, + ultimate.1.3:2.1.1, + ultimate.1.3:2.2.1, + ultimate.1.3:2.3.1, + ultimate.1.3:2.4.1, + ultimate.1.3:2.1.2, + ultimate.1.3:2.2.2, + ultimate.1.3:2.3.2, + ultimate.1.3:2.4.2, + ultimate.1.3:2.1.3, + ultimate.1.3:2.2.3, + ultimate.1.3:2.3.3, + ultimate.1.3:2.4.3, + ultimate.1.3:2.1.4, + ultimate.1.3:2.2.4, + ultimate.1.3:2.3.4, + ultimate.1.3:2.4.4, + ultimate.1.3:2.1.5, + ultimate.1.3:2.2.5, + ultimate.1.3:2.3.5, + ultimate.1.3:2.4.5, + ultimate.2.3:1.1:1, + ultimate.2.3:1.1:2.1, + ultimate.2.3:1.1:2.2, + ultimate.2.3:1.2:1, + ultimate.2.3:1.2:2.1, + ultimate.2.3:1.2:2.2, + ultimate.2.3:2.1.1, + ultimate.2.3:2.2.1, + ultimate.2.3:2.3.1, + ultimate.2.3:2.4.1, + ultimate.2.3:2.1.2, + ultimate.2.3:2.2.2, + ultimate.2.3:2.3.2, + ultimate.2.3:2.4.2, + ultimate.2.3:2.1.3, + ultimate.2.3:2.2.3, + ultimate.2.3:2.3.3, + ultimate.2.3:2.4.3, + ultimate.2.3:2.1.4, + ultimate.2.3:2.2.4, + ultimate.2.3:2.3.4, + ultimate.2.3:2.4.4, + ultimate.2.3:2.1.5, + ultimate.2.3:2.2.5, + ultimate.2.3:2.3.5, + ultimate.2.3:2.4.5 + "; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed[0].name, "a"); + assert_eq!(parsed[0].shape, vec![0usize; 0]); + + assert_eq!(parsed[1].name, "base"); + assert_eq!(parsed[1].shape, vec![0usize; 0]); + + assert_eq!(parsed[2].name, "base_i"); + assert_eq!(parsed[2].shape, vec![0usize; 0]); + + assert_eq!(parsed[3].name, "pair:1"); + assert_eq!(parsed[3].shape, vec![0usize; 0]); + + assert_eq!(parsed[4].name, "pair:2"); + assert_eq!(parsed[4].shape, vec![0usize; 0]); + + assert_eq!(parsed[5].name, "nested:1"); + assert_eq!(parsed[5].shape, vec![0usize; 0]); + + assert_eq!(parsed[6].name, "nested:2:1"); + assert_eq!(parsed[6].shape, vec![0usize; 0]); + + assert_eq!(parsed[7].name, "nested:2:2.real"); + assert_eq!(parsed[7].shape, vec![0usize; 0]); + + assert_eq!(parsed[8].name, "nested:2:2.imag"); + assert_eq!(parsed[8].shape, vec![0usize; 0]); + + assert_eq!(parsed[9].name, "arr_pair.1:1"); + assert_eq!(parsed[9].shape, vec![0usize; 0]); + + assert_eq!(parsed[10].name, "arr_pair.1:2"); + assert_eq!(parsed[10].shape, vec![0usize; 0]); + + assert_eq!(parsed[11].name, "arr_pair.2:1"); + assert_eq!(parsed[11].shape, vec![0usize; 0]); + + assert_eq!(parsed[12].name, "arr_pair.2:2"); + assert_eq!(parsed[12].shape, vec![0usize; 0]); + + assert_eq!(parsed[13].name, "arr_very_nested.1:1:1"); + assert_eq!(parsed[13].shape, vec![0usize; 0]); + + assert_eq!(parsed[14].name, "arr_very_nested.1:1:2:1"); + assert_eq!(parsed[14].shape, vec![0usize; 0]); + + assert_eq!(parsed[15].name, "arr_very_nested.1:1:2:2.real"); + assert_eq!(parsed[15].shape, vec![0usize; 0]); + + assert_eq!(parsed[16].name, "arr_very_nested.1:1:2:2.imag"); + assert_eq!(parsed[16].shape, vec![0usize; 0]); + + assert_eq!(parsed[17].name, "arr_very_nested.1:2"); + assert_eq!(parsed[17].shape, vec![0usize; 0]); + + assert_eq!(parsed[18].name, "arr_very_nested.2:1:1"); + assert_eq!(parsed[18].shape, vec![0usize; 0]); + + assert_eq!(parsed[19].name, "arr_very_nested.2:1:2:1"); + assert_eq!(parsed[19].shape, vec![0usize; 0]); + + assert_eq!(parsed[20].name, "arr_very_nested.2:1:2:2.real"); + assert_eq!(parsed[20].shape, vec![0usize; 0]); + + assert_eq!(parsed[21].name, "arr_very_nested.2:1:2:2.imag"); + assert_eq!(parsed[21].shape, vec![0usize; 0]); + + assert_eq!(parsed[22].name, "arr_very_nested.2:2"); + assert_eq!(parsed[22].shape, vec![0usize; 0]); + + assert_eq!(parsed[23].name, "arr_very_nested.3:1:1"); + assert_eq!(parsed[23].shape, vec![0usize; 0]); + + assert_eq!(parsed[24].name, "arr_very_nested.3:1:2:1"); + assert_eq!(parsed[24].shape, vec![0usize; 0]); + + assert_eq!(parsed[25].name, "arr_very_nested.3:1:2:2.real"); + assert_eq!(parsed[25].shape, vec![0usize; 0]); + + assert_eq!(parsed[26].name, "arr_very_nested.3:1:2:2.imag"); + assert_eq!(parsed[26].shape, vec![0usize; 0]); + + assert_eq!(parsed[27].name, "arr_very_nested.3:2"); + assert_eq!(parsed[27].shape, vec![0usize; 0]); + } } diff --git a/tests/test_stan.py b/tests/test_stan.py index f7ba498..f44b755 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -64,6 +64,90 @@ def test_seed(): assert not np.allclose(trace.posterior.b[i], trace3.posterior.b[j]) +@pytest.mark.stan +def test_nested(): + # Adapted from + # https://github.com/stan-dev/stanio/blob/main/test/data/tuples/output.stan + model = """ + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + generated quantities { + real base = normal_rng(0, 1); + int base_i = to_int(normal_rng(10, 10)); + + tuple(real, real) pair = (base, base * 2); + + tuple(real, tuple(int, complex)) nested = (base * 3, (base_i, base * 4.0i)); + array[2] tuple(real, real) arr_pair = {pair, (base * 5, base * 6)}; + + array[3] tuple(tuple(real, tuple(int, complex)), real) arr_very_nested + = {(nested, base*7), ((base*8, (base_i*2, base*9.0i)), base * 10), (nested, base*11)}; + + array[3,2] tuple(real, real) arr_2d_pair = {{(base * 12, base * 13), (base * 14, base * 15)}, + {(base * 16, base * 17), (base * 18, base * 19)}, + {(base * 20, base * 21), (base * 22, base * 23)}}; + + real basep1 = base + 1, basep2 = base + 2; + real basep3 = base + 3, basep4 = base + 4, basep5 = base + 5; + array[2,3] tuple(array[2] tuple(real, vector[2]), matrix[4,5]) ultimate = + { + {( + {(base, [base *2, base *3]'), (base *4, [base*5, base*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * base + ), + ( + {(basep1, [basep1 *2, basep1 *3]'), (basep1 *4, [basep1*5, basep1*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep1 + ), + ( + {(basep2, [basep2 *2, basep2 *3]'), (basep2 *4, [basep2*5, basep2*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep2 + ) + }, + {( + {(basep3, [basep3 *2, basep3 *3]'), (basep3 *4, [basep3*5, basep3*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep3 + ), + ( + {(basep4, [basep4 *2, basep4 *3]'), (basep4 *4, [basep4*5, basep4*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep4 + ), + ( + {(basep5, [basep5 *2, basep5 *3]'), (basep5 *4, [basep5*5, basep5*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep5 + ) + }}; + } + """ + + compiled = nutpie.compile_stan_model(code=model) + tr = nutpie.sample(compiled, chains=6) + base = tr.posterior.base + + assert np.allclose(tr.posterior["nested:2:2.imag"], 4 * base) + assert np.allclose(tr.posterior["nested:2:2.real"], 0.0) + + assert np.allclose(tr.posterior["ultimate.1.1:1.1:1"], base) + assert np.allclose(tr.posterior["ultimate.1.2:1.1:1"], base + 1) + assert np.allclose(tr.posterior["ultimate.1.3:1.1:1"], base + 2) + assert np.allclose(tr.posterior["ultimate.2.1:1.1:1"], base + 3) + assert np.allclose(tr.posterior["ultimate.2.2:1.1:1"], base + 4) + assert np.allclose(tr.posterior["ultimate.2.3:1.1:1"], base + 5) + + assert tr.posterior["ultimate.2.1:1.1:2"].shape == (6, 1000, 2) + assert np.allclose( + tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 0], 2 * (base + 5) + ) + assert np.allclose( + tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 1], 3 * (base + 5) + ) + assert np.allclose(tr.posterior["base_i"], tr.posterior.base_i.astype(int)) + + @pytest.mark.stan def test_stan_model_data(): model = """ From ede466e4a006be07a45e075d0a0d72696e98ce45 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 7 May 2025 15:44:40 +0200 Subject: [PATCH 3/3] style: fix some clippy warnings --- src/progress.rs | 2 +- src/pyfunc.rs | 15 +++++++-------- src/pymc.rs | 4 ++-- src/stan.rs | 6 +++--- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/progress.rs b/src/progress.rs index 2403e15..2b130a4 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -49,7 +49,7 @@ impl ProgressHandler { let progress = progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); let rendered = template.render_from(&self.engine, &progress).to_string(); - let rendered = rendered.unwrap_or_else(|err| format!("{}", err)); + let rendered = rendered.unwrap_or_else(|err| format!("{err}")); let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); progress_update_count += 1; }; diff --git a/src/pyfunc.rs b/src/pyfunc.rs index f23145e..de3ecd9 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -127,9 +127,8 @@ impl LogpError for PyLogpError { let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; - return attr - .is_truthy() - .expect("Could not access is_recoverable in error check"); + attr.is_truthy() + .expect("Could not access is_recoverable in error check") }), Self::ReturnTypeError() => false, Self::NotContiguousError(_) => false, @@ -151,7 +150,7 @@ impl PyDensity { transform_adapter: Option<&PyTransformAdapt>, ) -> Result { let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; - let transform_adapter = transform_adapter.map(|val| val.clone()); + let transform_adapter = transform_adapter.cloned(); Ok(Self { logp: logp_func, transform_adapter, @@ -185,7 +184,7 @@ impl CpuLogpFunc for PyDensity { ); Ok(logp_val) } - Err(err) => return Err(PyLogpError::PyError(err)), + Err(err) => Err(PyLogpError::PyError(err)), } }) } @@ -359,7 +358,7 @@ impl TensorShape { Self { shape, dims, size } } pub fn size(&self) -> usize { - return self.size; + self.size } } @@ -617,14 +616,14 @@ impl Model for PyModel { settings: &'model S, ) -> Result> { let draws = settings.hint_num_tune() + settings.hint_num_draws(); - Ok(PyTrace::new( + PyTrace::new( rng, chain_id, self.variables.clone(), &self.make_expand_func, draws, ) - .context("Could not create PyTrace object")?) + .context("Could not create PyTrace object") } fn math(&self) -> Result> { diff --git a/src/pymc.rs b/src/pymc.rs index b33b821..685b768 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -112,7 +112,7 @@ impl LogpError for ErrorCode { } } -impl<'a> CpuLogpFunc for &'a LogpFunc { +impl CpuLogpFunc for &LogpFunc { type LogpError = ErrorCode; type TransformParams = (); @@ -175,7 +175,7 @@ impl<'model> DrawStorage for PyMcTrace<'model> { let num_arrays = data.len() / size; let data = Float64Array::from(data); let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let offsets = OffsetBuffer::from_lengths((0..num_arrays).into_iter().map(|_| size)); + let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); let field = Field::new(name, DataType::LargeList(item_field), false); (Arc::new(field), Arc::new(array) as Arc) diff --git a/src/stan.rs b/src/stan.rs index f1e6d54..0586e6f 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -137,7 +137,7 @@ fn params(var_string: &str) -> anyhow::Result> { for (name, group) in &parsed_variables?.iter().chunk_by(|(name, _, _)| name) { // Find maximum shape and check if this is a complex variable let (shape, is_complex) = determine_variable_shape(group) - .context(format!("Error while parsing stan variable {}", name))?; + .context(format!("Error while parsing stan variable {name}"))?; // Calculate total size of this variable let size = shape.iter().product(); @@ -146,7 +146,7 @@ fn params(var_string: &str) -> anyhow::Result> { // Create Parameter objects (one for real and one for imag if complex) if is_complex { variables.push(Parameter { - name: format!("{}.real", name), + name: format!("{name}.real"), shape: shape.clone(), size, start_idx, @@ -155,7 +155,7 @@ fn params(var_string: &str) -> anyhow::Result> { start_idx = end_idx; end_idx = start_idx + size; variables.push(Parameter { - name: format!("{}.imag", name), + name: format!("{name}.imag"), shape, size, start_idx,