From 16281a87864cc633cf18dcb962b68de67a53632f Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 9 May 2025 21:32:29 +0200 Subject: [PATCH 1/5] fix: allow variables with zero shapes --- src/pymc.rs | 17 +++++++++++++---- src/stan.rs | 13 ++++++++++++- tests/test_pymc.py | 17 +++++++++++++++++ tests/test_stan.py | 19 +++++++++++++++++++ 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/pymc.rs b/src/pymc.rs index 685b768..220ad5f 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -15,6 +15,7 @@ use pyo3::{ Bound, Py, PyAny, PyObject, PyResult, Python, }; +use rand_distr::num_traits::CheckedEuclid; use thiserror::Error; type UserData = *const std::ffi::c_void; @@ -128,8 +129,8 @@ impl CpuLogpFunc for &LogpFunc { let retcode = unsafe { (self.func)( self.dim, - &position[0] as *const f64, - &mut gradient[0] as *mut f64, + position.as_ptr(), + gradient.as_mut_ptr(), logp_ptr, self.user_data_ptr, ) @@ -148,6 +149,7 @@ pub(crate) struct PyMcTrace<'model> { var_sizes: Vec, var_names: Vec, expand: &'model ExpandFunc, + count: usize, } impl<'model> DrawStorage for PyMcTrace<'model> { @@ -165,14 +167,20 @@ impl<'model> DrawStorage for PyMcTrace<'model> { data.extend_from_slice(vals); start = end; } + self.count += 1; + Ok(()) } fn finalize(self) -> Result> { let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes) .map(|(data, name, size)| { - assert!(data.len() % size == 0); - let num_arrays = data.len() / size; + let (num_arrays, rem) = data + .len() + .checked_div_rem_euclid(&size) + .unwrap_or((self.count, 0)); + assert!(rem == 0); + assert!(num_arrays == self.count); let data = Float64Array::from(data); let item_field = Arc::new(Field::new("item", DataType::Float64, false)); let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); @@ -206,6 +214,7 @@ impl<'model> PyMcTrace<'model> { var_sizes: model.var_sizes.clone(), var_names: model.var_names.clone(), expand: &model.expand, + count: 0, } } diff --git a/src/stan.rs b/src/stan.rs index 0586e6f..b10ac44 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -86,6 +86,9 @@ pub struct StanModel { /// Return meta information about the constrained parameters of the model fn params(var_string: &str) -> anyhow::Result> { + if var_string.is_empty() { + return Ok(vec![]); + } // Parse each variable string into (name, is_complex, indices) let parsed_variables: anyhow::Result)>> = var_string .split(',') @@ -540,6 +543,7 @@ pub struct StanTrace<'model> { trace: Vec>, expanded_buffer: Box<[f64]>, rng: bridgestan::Rng<&'model bridgestan::StanLibrary>, + count: usize, } impl<'model> Clone for StanTrace<'model> { @@ -559,6 +563,7 @@ impl<'model> Clone for StanTrace<'model> { trace: self.trace.clone(), expanded_buffer: self.expanded_buffer.clone(), rng, + count: self.count, } } } @@ -591,6 +596,7 @@ impl<'model> DrawStorage for StanTrace<'model> { // We need to transpose fortran_to_c_order(slice, &var.shape, trace); } + self.count += 1; Ok(()) } @@ -613,7 +619,7 @@ impl<'model> DrawStorage for StanTrace<'model> { .unzip(); Ok(Arc::new( - StructArray::try_new(fields.into(), arrays, None) + StructArray::try_new_with_length(fields.into(), arrays, None, self.count) .context("Could not create arrow StructArray")?, )) } @@ -649,6 +655,7 @@ impl Model for StanModel { trace, rng, expanded_buffer: buffer.into(), + count: 0, }) } @@ -734,6 +741,10 @@ mod tests { #[test] fn parse_vars() { + let vars = ""; + let parsed = super::params(vars).unwrap(); + assert!(parsed.len() == 0); + let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2"; let parsed = super::params(vars).unwrap(); assert!(parsed.len() == 1); diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 5c5d2b3..729b644 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -31,6 +31,23 @@ def test_pymc_model(backend, gradient_backend): trace.posterior.a # noqa: B018 +@pytest.mark.pymc +@parameterize_backends +def test_zero_size(backend, gradient_backend): + import pytensor.tensor as pt + + with pm.Model() as model: + a = pm.Normal("a", shape=(0, 0, 10)) + pm.Deterministic("b", pt.exp(a)) + + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + trace = nutpie.sample(compiled, chains=1, draws=17, tune=100) + assert trace.posterior.a.shape == (1, 17, 0, 0, 10) + assert trace.posterior.b.shape == (1, 17, 0, 0, 10) + + @pytest.mark.pymc @parameterize_backends def test_pymc_model_float32(backend, gradient_backend): diff --git a/tests/test_stan.py b/tests/test_stan.py index da84fc9..76b643e 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -27,6 +27,25 @@ def test_stan_model(): trace.posterior.a # noqa: B018 +@pytest.mark.stan +def test_empty(): + model = """ + data {} + parameters { + array[0] real a; + } + model { + a ~ normal(0, 1); + } + """ + + compiled_model = nutpie.compile_stan_model(code=model) + trace = nutpie.sample(compiled_model) # noqa: F841 + # TODO: Variable `a` is missing because of this bridgestan issue: + # https://github.com/roualdes/bridgestan/issues/278 + # assert trace.posterior.a.shape == (0, 1000) + + @pytest.mark.stan def test_seed(): model = """ From 15f93e9ebd51d99877a2e1bedb979f4401cf747e Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 9 May 2025 21:32:29 +0200 Subject: [PATCH 2/5] fix: let rust sampler decide on default num chains --- python/nutpie/sample.py | 85 ++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 356b8d0..bbb7b9c 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -455,18 +455,42 @@ def _repr_html_(self): def sample( compiled_model: CompiledModel, *, - draws: int | None, - tune: int | None, - chains: int, - cores: Optional[int], - seed: Optional[int], - save_warmup: bool, - progress_bar: bool, + draws: int | None = None, + tune: int | None = None, + chains: int | None = None, + cores: int | None = None, + seed: int | None = None, + save_warmup: bool = True, + progress_bar: bool = True, + low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, + init_mean: np.ndarray | None = None, + return_raw_trace: bool = False, + progress_template: str | None = None, + progress_style: str | None = None, + progress_rate: int = 100, +) -> arviz.InferenceData: ... + + +@overload +def sample( + compiled_model: CompiledModel, + *, + draws: int | None = None, + tune: int | None = None, + chains: int | None = None, + cores: int | None = None, + seed: int | None = None, + save_warmup: bool = True, + progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, transform_adapt: bool = False, - init_mean: Optional[np.ndarray], - return_raw_trace: bool, + init_mean: np.ndarray | None = None, + return_raw_trace: bool = False, blocking: Literal[True], + progress_template: str | None = None, + progress_style: str | None = None, + progress_rate: int = 100, **kwargs, ) -> arviz.InferenceData: ... @@ -475,18 +499,21 @@ def sample( def sample( compiled_model: CompiledModel, *, - draws: int | None, - tune: int | None, - chains: int, - cores: Optional[int], - seed: Optional[int], - save_warmup: bool, - progress_bar: bool, + draws: int | None = None, + tune: int | None = None, + chains: int | None = None, + cores: int | None = None, + seed: int | None = None, + save_warmup: bool = True, + progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, transform_adapt: bool = False, - init_mean: Optional[np.ndarray], - return_raw_trace: bool, + init_mean: np.ndarray | None = None, + return_raw_trace: bool = False, blocking: Literal[False], + progress_template: str | None = None, + progress_style: str | None = None, + progress_rate: int = 100, **kwargs, ) -> _BackgroundSampler: ... @@ -496,21 +523,21 @@ def sample( *, draws: int | None = None, tune: int | None = None, - chains: int = 6, - cores: Optional[int] = None, - seed: Optional[int] = None, + chains: int | None = None, + cores: int | None = None, + seed: int | None = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, transform_adapt: bool = False, - init_mean: Optional[np.ndarray] = None, + init_mean: np.ndarray | None = None, return_raw_trace: bool = False, blocking: bool = True, - progress_template: Optional[str] = None, - progress_style: Optional[str] = None, + progress_template: str | None = None, + progress_style: str | None = None, progress_rate: int = 100, **kwargs, -) -> arviz.InferenceData: +) -> arviz.InferenceData | _BackgroundSampler: """Sample the posterior distribution for a compiled model. Parameters @@ -618,7 +645,8 @@ def sample( settings.num_tune = tune if draws is not None: settings.num_draws = draws - settings.num_chains = chains + if chains is not None: + settings.num_chains = chains for name, val in kwargs.items(): setattr(settings, name, val) @@ -629,7 +657,10 @@ def sample( available = os.process_cpu_count() # type: ignore except AttributeError: available = os.cpu_count() - cores = min(chains, cast(int, available)) + if chains is None: + cores = available + else: + cores = min(chains, cast(int, available)) if init_mean is None: init_mean = np.zeros(compiled_model.n_dim) From 42721a95b75585a49380f0878bd59c63b38f19a3 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 9 May 2025 21:32:29 +0200 Subject: [PATCH 3/5] chore: update dependencies --- Cargo.lock | 617 +++++++++++++++++++++++------------------------------ Cargo.toml | 13 +- 2 files changed, 275 insertions(+), 355 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e3a940c..a3aa5e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21,16 +21,16 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.2.15", + "getrandom 0.3.3", "once_cell", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -71,15 +71,15 @@ checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "arrow" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc208515aa0151028e464cc94a692156e945ce5126abd3537bb7fd6ba2143ed1" +checksum = "b1bb018b6960c87fd9d025009820406f74e83281185a8bdcb44880d2aa5c9a87" dependencies = [ "arrow-arith", "arrow-array", @@ -95,9 +95,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e07e726e2b3f7816a85c6a45b6ec118eeeabf0b2a8c208122ad949437181f49a" +checksum = "44de76b51473aa888ecd6ad93ceb262fb8d40d1f1154a4df2f069b3590aa7575" dependencies = [ "arrow-array", "arrow-buffer", @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2262eba4f16c78496adfd559a29fe4b24df6088efc9985a873d58e92be022d5" +checksum = "29ed77e22744475a9a53d00026cf8e166fe73cf42d89c4c4ae63607ee1cfcc3f" dependencies = [ "ahash", "arrow-buffer", @@ -125,9 +125,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e899dade2c3b7f5642eb8366cfd898958bcca099cde6dfea543c7e8d3ad88d4" +checksum = "b0391c96eb58bf7389171d1e103112d3fc3e5625ca6b372d606f2688f1ea4cce" dependencies = [ "bytes", "half", @@ -136,9 +136,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4103d88c5b441525ed4ac23153be7458494c2b0c9a11115848fdb9b81f6f886a" +checksum = "f39e1d774ece9292697fcbe06b5584401b26bd34be1bec25c33edae65c2420ff" dependencies = [ "arrow-array", "arrow-buffer", @@ -156,9 +156,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a329fb064477c9ec5f0870d2f5130966f91055c7c5bce2b3a084f116bc28c3b" +checksum = "cf75ac27a08c7f48b88e5c923f267e980f27070147ab74615ad85b5c5f90473d" dependencies = [ "arrow-buffer", "arrow-schema", @@ -168,9 +168,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f841bfcc1997ef6ac48ee0305c4dfceb1f7c786fe31e67c1186edf775e1f1160" +checksum = "ab2f1065a5cad7b9efa9e22ce5747ce826aa3855766755d4904535123ef431e7" dependencies = [ "arrow-array", "arrow-buffer", @@ -181,9 +181,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1eeb55b0a0a83851aa01f2ca5ee5648f607e8506ba6802577afdda9d75cdedcd" +checksum = "3703a0e3e92d23c3f756df73d2dc9476873f873a76ae63ef9d3de17fda83b2d8" dependencies = [ "arrow-array", "arrow-buffer", @@ -194,18 +194,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85934a9d0261e0fa5d4e2a5295107d743b543a6e0484a835d4b8db2da15306f9" +checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", ] [[package]] name = "arrow-select" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e2932aece2d0c869dd2125feb9bd1709ef5c445daa3838ac4112dcfa0fda52c" +checksum = "24b7b85575702b23b85272b01bc1c25a01c9b9852305e5d0078c79ba25d995d4" dependencies = [ "ahash", "arrow-array", @@ -217,9 +217,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "54.2.1" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "912e38bd6a7a7714c1d9b61df80315685553b7455e8a6045c27531d8ecd5b458" +checksum = "9260fddf1cdf2799ace2b4c2fc0356a9789fa7551e0953e35435536fecefebbd" dependencies = [ "arrow-array", "arrow-buffer", @@ -281,9 +281,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "block-buffer" @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "bridgestan" -version = "2.6.1" +version = "2.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcc63930a0b64b9c4ca744d32aa2c664bcc4fdd34bd68778e531572318cd0bc" +checksum = "3fcf23cdd20237d4699464b803c6aef49f547266514c7361c27b25875ee69298" dependencies = [ "bindgen", "libloading", @@ -315,23 +315,9 @@ checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.8.1" +version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "9134a6ef01ce4b366b50689c94f82c14bc72bc5d0386829828a2e2752ef7958c" [[package]] name = "byteorder" @@ -373,9 +359,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.16" +version = "1.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" +checksum = "16595d3be041c03b09d08d0858631facccee9221e579704070e6e9e4915d3bc7" dependencies = [ "jobserver", "libc", @@ -399,14 +385,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.39" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets", + "windows-link", ] [[package]] @@ -459,18 +445,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.32" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83" +checksum = "ed93b9805f8ba930df42c2590f05453d5ec36cbb85d018868a5b24d31f6ac000" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.32" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8" +checksum = "379026ff283facf611b0ea629334361c4211d1b12ee01024eec1591133b04120" dependencies = [ "anstyle", "clap_lex", @@ -510,7 +496,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] @@ -547,25 +533,22 @@ dependencies = [ [[package]] name = "criterion" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679" dependencies = [ "anes", "cast", "ciborium", "clap", "criterion-plot", - "is-terminal", - "itertools 0.10.5", + "itertools 0.13.0", "num-traits", - "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "walkdir", @@ -663,18 +646,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" -[[package]] -name = "enum-as-inner" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "equator" version = "0.2.2" @@ -717,9 +688,9 @@ dependencies = [ [[package]] name = "faer" -version = "0.21.9" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebe9ac2a073e05ca749eeea503fae16a91440b20d2e92b6fc6f6c6919b9964eb" +checksum = "49fce40ad65c366fbc6cd70a99d09d1008f075280bf2455e558e163c82913a9f" dependencies = [ "bytemuck", "dyn-stack", @@ -730,7 +701,6 @@ dependencies = [ "generativity", "libm", "nano-gemm", - "npyz", "num-complex", "num-traits", "pulp", @@ -750,9 +720,9 @@ dependencies = [ [[package]] name = "faer-traits" -version = "0.21.5" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1430e111b20872c7eaa82c7ada071bff1c3e3ac09cc6f4df676065fd2d41eb62" +checksum = "54febfcbb90edaab562d85447a94d500f1601f11db0b30d27da87ed6542c8f91" dependencies = [ "bytemuck", "dyn-stack", @@ -762,14 +732,15 @@ dependencies = [ "num-complex", "num-traits", "pulp", + "qd", "reborrow", ] [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", "miniz_oxide", @@ -785,7 +756,6 @@ dependencies = [ "gemm-c32", "gemm-c64", "gemm-common", - "gemm-f16", "gemm-f32", "gemm-f64", "num-complex", @@ -833,7 +803,6 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" dependencies = [ "bytemuck", "dyn-stack", - "half", "libm", "num-complex", "num-traits", @@ -842,24 +811,6 @@ dependencies = [ "pulp", "raw-cpuid", "seq-macro", - "sysctl", -] - -[[package]] -name = "gemm-f16" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" -dependencies = [ - "dyn-stack", - "gemm-common", - "gemm-f32", - "half", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", ] [[package]] @@ -910,9 +861,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -921,9 +872,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "libc", @@ -939,11 +890,10 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ - "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -951,9 +901,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" [[package]] name = "heck" @@ -961,12 +911,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" - [[package]] name = "hmac" version = "0.12.1" @@ -978,14 +922,15 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", "windows-core", ] @@ -1027,17 +972,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "is-terminal" -version = "0.4.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys", -] - [[package]] name = "itertools" version = "0.10.5" @@ -1073,10 +1007,11 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -1162,37 +1097,37 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libloading" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets", + "windows-targets 0.53.0", ] [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "log" -version = "0.4.26" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "matrixmultiply" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" dependencies = [ "autocfg", "rawpointer", @@ -1221,40 +1156,18 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.5" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] -[[package]] -name = "multiversion" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edb7f0ff51249dfda9ab96b5823695e15a052dc15074c9dbf3d118afaf2c201" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b093064383341eb3271f42e381cb8f10a01459478446953953c75d24bd339fc0" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "target-features", -] - [[package]] name = "nano-gemm" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f563548d38f390ef9893e4883ec38c1fb312f569e98d76bededdd91a3b41a043" +checksum = "bb5ba2bea1c00e53de11f6ab5bd0761ba87dc0045d63b0c87ee471d2d3061376" dependencies = [ "equator 0.2.2", "nano-gemm-c32", @@ -1345,17 +1258,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "npyz" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f27ea175875c472b3df61ece89a6d6ef4e0627f43704e400c782f174681ebd" -dependencies = [ - "byteorder", - "num-bigint", - "py_literal", -] - [[package]] name = "num" version = "0.4.3" @@ -1445,9 +1347,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "numpy" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7cfbf3f0feededcaa4d289fe3079b03659e85c5b5a177f4ba6fb01ab4fb3e39" +checksum = "29f1dee9aa8d3f6f8e8b9af3803006101bb3653866ef056d530d53ae68587191" dependencies = [ "libc", "ndarray", @@ -1461,7 +1363,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.14.3" +version = "0.15.0" dependencies = [ "anyhow", "arrow", @@ -1472,7 +1374,7 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", - "rand 0.9.0", + "rand 0.9.1", "rand_chacha 0.9.0", "rand_distr", "rayon", @@ -1485,17 +1387,16 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.15.1" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d3052cf8ae044673a4bb41819943e62e43af7c1443f45c6e2f8c895e9fa994" +checksum = "acad2be84df0d14341d8de7d30c1019ecc008f4722befbd45745092a918c0a02" dependencies = [ "anyhow", "arrow", "faer", "itertools 0.14.0", - "multiversion", "pulp", - "rand 0.9.0", + "rand 0.9.1", "rand_chacha 0.9.0", "rand_distr", "rayon", @@ -1504,9 +1405,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.1" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" @@ -1561,51 +1462,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "pest" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" -dependencies = [ - "memchr", - "thiserror 2.0.12", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" -dependencies = [ - "pest", - "pest_generator", -] - -[[package]] -name = "pest_generator" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "pest_meta" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" -dependencies = [ - "once_cell", - "pest", - "sha2", -] - [[package]] name = "pkg-config" version = "0.3.32" @@ -1667,14 +1523,14 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.8.23", + "zerocopy", ] [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", "syn", @@ -1682,18 +1538,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] [[package]] name = "pulp" -version = "0.21.4" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95fb7a99b37aaef4c7dd2fd15a819eb8010bfc7a2c2155230d51f497316cad6d" +checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907" dependencies = [ "bytemuck", "cfg-if", @@ -1703,27 +1559,13 @@ dependencies = [ "version_check", ] -[[package]] -name = "py_literal" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" -dependencies = [ - "num-bigint", - "num-complex", - "num-traits", - "pest", - "pest_derive", -] - [[package]] name = "pyo3" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229" +checksum = "f239d656363bcee73afef85277f1b281e8ac6212a1d42aa90e55b90ed43c47a4" dependencies = [ "anyhow", - "cfg-if", "indoc", "libc", "memoffset", @@ -1737,9 +1579,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1" +checksum = "755ea671a1c34044fa165247aaf6f419ca39caa6003aee791a0df2713d8f1b6d" dependencies = [ "once_cell", "target-lexicon", @@ -1747,9 +1589,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc" +checksum = "fc95a2e67091e44791d4ea300ff744be5293f394f1bafd9f78c080814d35956e" dependencies = [ "libc", "pyo3-build-config", @@ -1757,9 +1599,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44" +checksum = "a179641d1b93920829a62f15e87c0ed791b6c8db2271ba0fd7c2686090510214" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1769,9 +1611,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855" +checksum = "9dff85ebcaab8c441b0e3f7ae40a6963ecea8a9f5e74f647e33fcf5ec9a1e89e" dependencies = [ "heck", "proc-macro2", @@ -1780,6 +1622,18 @@ dependencies = [ "syn", ] +[[package]] +name = "qd" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73940173cf92cd24f3650f5f388946524026712a6ca170762340acf5fb3fde0f" +dependencies = [ + "bytemuck", + "libm", + "num-traits", + "pulp", +] + [[package]] name = "quote" version = "1.0.40" @@ -1808,13 +1662,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.23", ] [[package]] @@ -1843,7 +1696,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -1852,7 +1705,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.2", + "getrandom 0.3.3", ] [[package]] @@ -1862,7 +1715,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.0", + "rand 0.9.1", ] [[package]] @@ -1943,9 +1796,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "ryu" @@ -2023,9 +1876,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -2040,9 +1893,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "static_assertions" @@ -2058,35 +1911,15 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] -[[package]] -name = "sysctl" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" -dependencies = [ - "bitflags", - "byteorder", - "enum-as-inner", - "libc", - "thiserror 1.0.69", - "walkdir", -] - -[[package]] -name = "target-features" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" - [[package]] name = "target-lexicon" version = "0.13.2" @@ -2152,9 +1985,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.40" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d9c75b47bdff86fa3334a3db91356b8d7d86a9b839dab7d0bdc5c3d3a077618" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "num-conv", @@ -2212,12 +2045,6 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - [[package]] name = "unicode-ident" version = "1.0.18" @@ -2362,11 +2189,61 @@ dependencies = [ [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ - "windows-targets", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", ] [[package]] @@ -2375,7 +2252,7 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2384,14 +2261,30 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] @@ -2400,48 +2293,96 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "wit-bindgen-rt" version = "0.39.0" @@ -2453,38 +2394,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd97444d05a4328b90e75e503a34bad781f14e28a823ad3557f0750df1ebcbc6" -dependencies = [ - "zerocopy-derive 0.8.23", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - "proc-macro2", - "quote", - "syn", + "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.23" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6352c01d0edd5db859a63e2605f4ea3183ddbd15e2c4a9e7d32184df75e4f154" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", @@ -2532,9 +2453,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.14+zstd.1.5.7" +version = "2.0.15+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index dfc640f..84081ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.14.3" +version = "0.15.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -15,21 +15,20 @@ rust-version = "1.76" [features] extension-module = ["pyo3/extension-module"] default = ["extension-module"] -simd_support = ["nuts-rs/simd_support"] [lib] name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.15.1" -numpy = "0.24.0" +nuts-rs = "0.16.1" +numpy = "0.25.0" rand = "0.9.0" thiserror = "2.0.3" rand_chacha = "0.9.0" rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "54.2.0", default-features = false, features = ["ffi"] } +arrow = { version = "55.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.14.0" bridgestan = "2.6.1" @@ -41,11 +40,11 @@ indicatif = "0.17.8" tch = { version = "0.20.0", optional = true } [dependencies.pyo3] -version = "0.24.1" +version = "0.25.0" features = ["extension-module", "anyhow"] [dev-dependencies] -criterion = "0.5.1" +criterion = "0.6.0" [profile.release] lto = "fat" From 2de993e03b39a1a67f8d5f553cc2530153355269 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 9 May 2025 21:32:29 +0200 Subject: [PATCH 4/5] test: add low rank tests --- tests/test_pymc.py | 26 ++++++++++++++++++++++++++ tests/test_stan.py | 19 ++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 729b644..51e2159 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -31,6 +31,32 @@ def test_pymc_model(backend, gradient_backend): trace.posterior.a # noqa: B018 +@pytest.mark.pymc +@parameterize_backends +def test_low_rank(backend, gradient_backend): + with pm.Model() as model: + pm.Normal("a") + + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) + trace.posterior.a # noqa: B018 + + +@pytest.mark.pymc +@parameterize_backends +def test_low_rank_half_normal(backend, gradient_backend): + with pm.Model() as model: + pm.HalfNormal("a", shape=13) + + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) + trace.posterior.a # noqa: B018 + + @pytest.mark.pymc @parameterize_backends def test_zero_size(backend, gradient_backend): diff --git a/tests/test_stan.py b/tests/test_stan.py index 76b643e..53b6b40 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -27,6 +27,23 @@ def test_stan_model(): trace.posterior.a # noqa: B018 +@pytest.mark.stan +def test_stan_model_low_rank(): + model = """ + data {} + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + """ + + compiled_model = nutpie.compile_stan_model(code=model) + trace = nutpie.sample(compiled_model, low_rank_modified_mass_matrix=True) + trace.posterior.a # noqa: B018 + + @pytest.mark.stan def test_empty(): model = """ @@ -40,7 +57,7 @@ def test_empty(): """ compiled_model = nutpie.compile_stan_model(code=model) - trace = nutpie.sample(compiled_model) # noqa: F841 + nutpie.sample(compiled_model) # TODO: Variable `a` is missing because of this bridgestan issue: # https://github.com/roualdes/bridgestan/issues/278 # assert trace.posterior.a.shape == (0, 1000) From e7eb9a5488578197808d5d9d544303584e46e086 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 9 May 2025 21:32:29 +0200 Subject: [PATCH 5/5] chore: update changelog --- CHANGELOG.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8eab10f..a07cb81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,63 @@ All notable changes to this project will be documented in this file. +## [0.15.0] - 2025-05-27 + +### Bug Fixes + +- Use stanio for creating Stan's data JSON (#205) (Brian Ward) + +- Rng for generated quantities (Adrian Seyboldt) + +- Correctly handle tuples in stan traces (Adrian Seyboldt) + +- Allow variables with zero shapes (Adrian Seyboldt) + +- Let rust sampler decide on default num chains (Adrian Seyboldt) + + +### Documentation + +- Fix section links path (Guspan Tanadi) + +- Link to website (Adrian Seyboldt) + + +### Features + +- Improvements to normalizing flow (Adrian Seyboldt) + +- Experiment with planar flows (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Bump pyo3 in the cargo group across 1 directory (dependabot[bot]) + +- Bump astral-sh/setup-uv from 5 to 6 (#203) (dependabot[bot]) + +- Add entries to gitignore (Adrian Seyboldt) + +- Bump dependencies (Adrian Seyboldt) + + +### Styling + +- Fix some clippy warnings (Adrian Seyboldt) + + +### Testing + +- Check that normalizing flows are reproducible (Adrian Seyboldt) + +- Add low rank tests (Adrian Seyboldt) + + +### Build + +- Increase optimization level (Adrian Seyboldt) + + ## [0.14.3] - 2025-03-18 ### Bug Fixes @@ -11,6 +68,11 @@ All notable changes to this project will be documented in this file. - Better initialization of masked flows (Adrian Seyboldt) +### Documentation + +- Fix spelling and grammar (Daniel Saunders) + + ### Features - Add masked coupling flow (Adrian Seyboldt)