Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,25 @@

All notable changes to this project will be documented in this file.

## [0.17.1] - 2025-11-13

### Bug Fixes

- Store step size info in transform_adapt_strategy (Adrian Seyboldt)

- Mindepth when check_turning=True was misbehaving (Adrian Seyboldt)


### Features

- Support datetime coordinates (Adrian Seyboldt)


### Miscellaneous Tasks

- Update dependencies (Adrian Seyboldt)


## [0.17.0] - 2025-10-08

### Bug Fixes
Expand Down Expand Up @@ -52,6 +71,10 @@ All notable changes to this project will be documented in this file.

- Update dependencies (Adrian Seyboldt)

- Prepare 0.17.0 (Adrian Seyboldt)

- Correctly specify dependencies in workspace (Adrian Seyboldt)


### Performance

Expand Down
12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "nuts-rs"
version = "0.17.0"
version = "0.17.1"
authors = [
"Adrian Seyboldt <adrian.seyboldt@gmail.com>",
"PyMC Developers <pymc.devs@gmail.com>",
Expand Down Expand Up @@ -33,9 +33,9 @@ zarrs = { version = "0.22.0", features = [
"sharding",
"async",
], optional = true }
ndarray = { version = "0.16.1", optional = true }
arrow = { version = "56.2.0", optional = true }
arrow-schema = { version = "56.2.0", features = [
ndarray = { version = "0.17.1", optional = true }
arrow = { version = "57.0.0", optional = true }
arrow-schema = { version = "57.0.0", features = [
"canonical_extension_types",
], optional = true }
nuts-derive = { path = "./nuts-derive", version = "0.1.0" }
Expand All @@ -50,9 +50,9 @@ pretty_assertions = "1.4.0"
criterion = "0.7.0"
nix = { version = "0.30.0", features = ["sched"] }
approx = "0.5.1"
equator = "0.4.2"
equator = "0.4.0"
serde_json = "1.0"
ndarray = "0.16.1"
ndarray = "0.17.1"
tempfile = "3.0"
zarrs_object_store = "0.5.0"
object_store = "0.12.0"
Expand Down
12 changes: 12 additions & 0 deletions nuts-storable/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DateTimeUnit {
Seconds,
Milliseconds,
Microseconds,
Nanoseconds,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ItemType {
U64,
Expand All @@ -8,6 +16,8 @@ pub enum ItemType {
F32,
Bool,
String,
DateTime64(DateTimeUnit),
TimeDelta64(DateTimeUnit),
}

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -18,6 +28,8 @@ pub enum Value {
F32(Vec<f32>),
Bool(Vec<bool>),
ScalarString(String),
DateTime64(DateTimeUnit, Vec<i64>),
TimeDelta64(DateTimeUnit, Vec<i64>),
ScalarU64(u64),
ScalarI64(i64),
ScalarF64(f64),
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ mod transform_adapt_strategy;
mod transformed_hamiltonian;

pub use nuts_derive::Storable;
pub use nuts_storable::{HasDims, ItemType, Storable, Value};
pub use nuts_storable::{DateTimeUnit, HasDims, ItemType, Storable, Value};

pub use adapt_strategy::EuclideanAdaptOptions;
pub use chain::Chain;
Expand Down
42 changes: 34 additions & 8 deletions src/nuts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ pub struct NutsOptions {
pub store_divergences: bool,
}

impl Default for NutsOptions {
fn default() -> Self {
NutsOptions {
maxdepth: 10,
mindepth: 0,
store_gradient: false,
store_unconstrained: false,
check_turning: true,
store_divergences: false,
}
}
}

pub(crate) fn draw<M, H, R, C>(
math: &mut M,
init: &mut State<M, H::Point>,
Expand All @@ -282,18 +295,31 @@ where
return Ok((init.clone(), info));
}

let options_no_check = NutsOptions {
check_turning: false,
..*options
};

while tree.depth < options.maxdepth {
let direction: Direction = rng.random();
tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
let current_options = if tree.depth < options.mindepth {
&options_no_check
} else {
options
};
tree = match tree.extend(
math,
rng,
hamiltonian,
direction,
collector,
current_options,
) {
ExtendResult::Ok(tree) => tree,
ExtendResult::Turning(tree) => {
if tree.depth < options.mindepth {
tree
} else {
let info = tree.info(false, None);
collector.register_draw(math, &tree.draw, &info);
return Ok((tree.draw, info));
}
let info = tree.info(false, None);
collector.register_draw(math, &tree.draw, &info);
return Ok((tree.draw, info));
}
ExtendResult::Diverging(tree, info) => {
let info = tree.info(false, Some(info));
Expand Down
24 changes: 24 additions & 0 deletions src/storage/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ impl ArrowBuilder {
ItemType::I64 => Box::new(Int64Builder::with_capacity(capacity)),
ItemType::U64 => Box::new(UInt64Builder::with_capacity(capacity)),
ItemType::String => Box::new(StringBuilder::with_capacity(capacity, capacity)),
ItemType::DateTime64(_) => {
panic!("DateTime values not supported as values in arrow storage")
}
ItemType::TimeDelta64(_) => {
panic!("TimeDelta values not supported as values in arrow storage")
}
};

if shape.is_empty() {
Expand Down Expand Up @@ -100,6 +106,12 @@ impl ArrowBuilder {
string_builder.append_value(&item);
}
}
Value::DateTime64(_, _) => {
panic!("DateTime64 scalar values not supported in arrow storage")
}
Value::TimeDelta64(_, _) => {
panic!("TimeDelta64 scalar values not supported in arrow storage")
}
},
ArrowBuilder::Tensor(list_builder) => {
match value {
Expand Down Expand Up @@ -154,6 +166,12 @@ impl ArrowBuilder {
downcast_builder!(list_builder.values(), BooleanBuilder, ScalarBool)?
.append_value(val);
}
Value::DateTime64(_, _) => {
panic!("DateTime64 scalar values not supported in arrow storage")
}
Value::TimeDelta64(_, _) => {
panic!("TimeDelta64 scalar values not supported in arrow storage")
}
}
list_builder.append(true);
}
Expand Down Expand Up @@ -211,6 +229,12 @@ fn item_type_to_arrow_type(item_type: ItemType) -> DataType {
ItemType::I64 => DataType::Int64,
ItemType::Bool => DataType::Boolean,
ItemType::String => DataType::Utf8,
ItemType::DateTime64(_) => {
panic!("DateTime64 scalar values not supported in arrow storage")
}
ItemType::TimeDelta64(_) => {
panic!("TimeDelta64 scalar values not supported in arrow storage")
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/storage/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ impl CsvChainStorage {
vec[0].clone()
}
}
Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"),
Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"),
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/storage/hashmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ impl HashMapValue {
ItemType::I64 => HashMapValue::I64(Vec::new()),
ItemType::U64 => HashMapValue::U64(Vec::new()),
ItemType::String => HashMapValue::String(Vec::new()),
ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
}
}

Expand All @@ -45,6 +46,11 @@ impl HashMapValue {
(HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
(HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),

(HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
(HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
(HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
(HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),

_ => panic!("Mismatched item type"),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/storage/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ impl NdarrayValue {
ItemType::String => {
NdarrayValue::String(ArrayD::from_elem(IxDyn(shape), String::new()))
}
ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => {
NdarrayValue::I64(ArrayD::zeros(IxDyn(shape)))
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions src/storage/zarr/async_impl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::iter::once;
use std::num::NonZero;
use std::sync::Arc;
use tokio::task::JoinHandle;

Expand All @@ -8,6 +9,7 @@ use nuts_storable::{ItemType, Value};
use zarrs::array::{ArrayBuilder, DataType, FillValue};
use zarrs::array_subset::ArraySubset;
use zarrs::group::GroupBuilder;
use zarrs::metadata_ext::data_type::NumpyTimeUnit;
use zarrs::storage::{
AsyncReadableWritableListableStorage, AsyncReadableWritableListableStorageTraits,
};
Expand Down Expand Up @@ -140,6 +142,38 @@ async fn store_coords(
&Value::I64(ref v) => (DataType::Int64, v.len(), FillValue::from(0i64)),
&Value::Bool(ref v) => (DataType::Bool, v.len(), FillValue::from(false)),
&Value::Strings(ref v) => (DataType::String, v.len(), FillValue::from("")),
&Value::DateTime64(unit, ref v) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
(
DataType::NumpyDateTime64 {
unit,
scale_factor: NonZero::new(1).unwrap(),
},
v.len(),
FillValue::from(0i64),
)
}
&Value::TimeDelta64(unit, ref v) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
(
DataType::NumpyTimeDelta64 {
unit,
scale_factor: NonZero::new(1).unwrap(),
},
v.len(),
FillValue::from(0i64),
)
}
_ => panic!("Unsupported coordinate type for {}", name),
};
let name: &String = name;
Expand Down Expand Up @@ -179,6 +213,16 @@ async fn store_coords(
.async_store_chunk_elements::<String>(&subset, v)
.await?
}
&Value::DateTime64(_, ref data) => {
coord_array
.async_store_chunk_elements::<i64>(&subset, data)
.await?
}
&Value::TimeDelta64(_, ref data) => {
coord_array
.async_store_chunk_elements::<i64>(&subset, data)
.await?
}
_ => unreachable!(),
}
coord_array.async_store_metadata().await?;
Expand Down
25 changes: 24 additions & 1 deletion src/storage/zarr/common.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::collections::HashMap;
use std::mem::replace;
use std::sync::Arc;
use std::{collections::HashMap, num::NonZero};

use anyhow::Result;
use nuts_storable::{ItemType, Value};
use zarrs::array::{Array, ArrayBuilder, DataType, FillValue};
use zarrs::metadata_ext::data_type::NumpyTimeUnit;

/// Container for different types of sample values
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -51,6 +52,8 @@ impl SampleBuffer {
ItemType::Bool => SampleBufferValue::Bool(Vec::with_capacity(chunk_size)),
ItemType::I64 => SampleBufferValue::I64(Vec::with_capacity(chunk_size)),
ItemType::String => panic!("String type not supported in SampleBuffer"),
ItemType::DateTime64(_) => panic!("DateTime64 type not supported in SampleBuffer"),
ItemType::TimeDelta64(_) => panic!("TimeDelta64 type not supported in SampleBuffer"),
};
Self {
items: inner,
Expand Down Expand Up @@ -196,6 +199,24 @@ pub fn create_arrays<TStorage: ?Sized>(
ItemType::I64 => DataType::Int64,
ItemType::Bool => DataType::Bool,
ItemType::String => DataType::String,
ItemType::DateTime64(unit) => DataType::NumpyDateTime64 {
unit: match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
},
scale_factor: NonZero::new(1).unwrap(),
},
ItemType::TimeDelta64(unit) => DataType::NumpyTimeDelta64 {
unit: match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
},
scale_factor: NonZero::new(1).unwrap(),
},
};
let fill_value = match item_type {
ItemType::F64 => FillValue::from(f64::NAN),
Expand All @@ -204,6 +225,8 @@ pub fn create_arrays<TStorage: ?Sized>(
ItemType::I64 => FillValue::from(0i64),
ItemType::Bool => FillValue::from(false),
ItemType::String => FillValue::from(""),
ItemType::DateTime64(_) => FillValue::new_null(),
ItemType::TimeDelta64(_) => FillValue::new_null(),
};
let grid: Vec<u64> = std::iter::once(1)
.chain(std::iter::once(draw_chunk_size))
Expand Down
Loading
Loading