Skip to content

Commit 244eb13

Browse files
committed
feat: add arguments for walnuts
1 parent 26e085c commit 244eb13

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

src/wrapper.rs

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use arrow::array::Array;
1717
use numpy::{PyArray1, PyReadonlyArray1};
1818
use nuts_rs::{
1919
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler,
20-
SamplerWaitResult, Trace, TransformedNutsSettings,
20+
SamplerWaitResult, Trace, TransformedNutsSettings, WalnutsOptions,
2121
};
2222
use pyo3::{
2323
exceptions::PyTimeoutError,
@@ -654,6 +654,134 @@ impl PyNutsSettings {
654654
}
655655
Ok(())
656656
}
657+
658+
#[getter]
659+
fn max_step_size_halvings(&self) -> Result<Option<u64>> {
660+
let walnuts = match &self.inner {
661+
Settings::LowRank(inner) => inner.walnuts_options,
662+
Settings::Diag(inner) => inner.walnuts_options,
663+
Settings::Transforming(inner) => inner.walnuts_options,
664+
};
665+
if let Some(walnuts) = walnuts {
666+
Ok(Some(walnuts.max_step_size_halvings))
667+
} else {
668+
Ok(None)
669+
}
670+
}
671+
672+
#[setter(max_step_size_halvings)]
673+
fn set_max_step_size_halvings(&mut self, val: Option<u64>) -> Result<()> {
674+
let options = match &mut self.inner {
675+
Settings::LowRank(inner) => &mut inner.walnuts_options,
676+
Settings::Diag(inner) => &mut inner.walnuts_options,
677+
Settings::Transforming(inner) => &mut inner.walnuts_options,
678+
};
679+
680+
if let Some(max_halvings) = val {
681+
if let Some(ref mut options) = options {
682+
options.max_step_size_halvings = max_halvings;
683+
} else {
684+
let mut new_options = WalnutsOptions::default();
685+
new_options.max_step_size_halvings = max_halvings;
686+
*options = Some(new_options);
687+
}
688+
} else {
689+
*options = None;
690+
}
691+
692+
Ok(())
693+
}
694+
695+
#[getter]
696+
fn max_walnuts_energy_error(&self) -> Result<Option<f64>> {
697+
let walnuts = match &self.inner {
698+
Settings::LowRank(inner) => inner.walnuts_options,
699+
Settings::Diag(inner) => inner.walnuts_options,
700+
Settings::Transforming(inner) => inner.walnuts_options,
701+
};
702+
if let Some(walnuts) = walnuts {
703+
Ok(Some(walnuts.max_energy_error))
704+
} else {
705+
Ok(None)
706+
}
707+
}
708+
709+
#[setter(max_walnuts_energy_error)]
710+
fn set_max_walnuts_energy_error(&mut self, val: Option<f64>) -> Result<()> {
711+
let options = match &mut self.inner {
712+
Settings::LowRank(inner) => &mut inner.walnuts_options,
713+
Settings::Diag(inner) => &mut inner.walnuts_options,
714+
Settings::Transforming(inner) => &mut inner.walnuts_options,
715+
};
716+
717+
if let Some(max_error) = val {
718+
if let Some(ref mut options) = options {
719+
options.max_energy_error = max_error;
720+
} else {
721+
let mut new_options = WalnutsOptions::default();
722+
new_options.max_energy_error = max_error;
723+
*options = Some(new_options);
724+
}
725+
} else {
726+
*options = None;
727+
}
728+
729+
Ok(())
730+
}
731+
732+
#[getter]
733+
fn fixed_step_size(&self) -> Result<Option<f64>> {
734+
match &self.inner {
735+
Settings::LowRank(inner) => {
736+
Ok(inner.adapt_options.dual_average_options.fixed_step_size)
737+
}
738+
Settings::Diag(inner) => Ok(inner.adapt_options.dual_average_options.fixed_step_size),
739+
Settings::Transforming(inner) => {
740+
Ok(inner.adapt_options.dual_average_options.fixed_step_size)
741+
}
742+
}
743+
}
744+
745+
#[setter(fixed_step_size)]
746+
fn set_fixed_step_size(&mut self, val: Option<f64>) -> Result<()> {
747+
match &mut self.inner {
748+
Settings::LowRank(inner) => {
749+
inner.adapt_options.dual_average_options.fixed_step_size = val;
750+
}
751+
Settings::Diag(inner) => {
752+
inner.adapt_options.dual_average_options.fixed_step_size = val;
753+
}
754+
Settings::Transforming(inner) => {
755+
inner.adapt_options.dual_average_options.fixed_step_size = val;
756+
}
757+
}
758+
Ok(())
759+
}
760+
761+
#[getter]
762+
fn step_size_jitter(&self) -> Result<Option<f64>> {
763+
match &self.inner {
764+
Settings::LowRank(inner) => Ok(inner.adapt_options.dual_average_options.jitter),
765+
Settings::Diag(inner) => Ok(inner.adapt_options.dual_average_options.jitter),
766+
Settings::Transforming(inner) => Ok(inner.adapt_options.dual_average_options.jitter),
767+
}
768+
}
769+
770+
#[setter(step_size_jitter)]
771+
fn set_step_size_jitter(&mut self, val: Option<f64>) -> Result<()> {
772+
match &mut self.inner {
773+
Settings::LowRank(inner) => {
774+
inner.adapt_options.dual_average_options.jitter = val;
775+
}
776+
Settings::Diag(inner) => {
777+
inner.adapt_options.dual_average_options.jitter = val;
778+
}
779+
Settings::Transforming(inner) => {
780+
inner.adapt_options.dual_average_options.jitter = val;
781+
}
782+
}
783+
Ok(())
784+
}
657785
}
658786

659787
pub(crate) enum SamplerState {

0 commit comments

Comments
 (0)