@@ -17,7 +17,7 @@ use arrow::array::Array;
1717use numpy:: { PyArray1 , PyReadonlyArray1 } ;
1818use nuts_rs:: {
1919 ChainProgress , DiagGradNutsSettings , LowRankNutsSettings , ProgressCallback , Sampler ,
20- SamplerWaitResult , Trace , TransformedNutsSettings ,
20+ SamplerWaitResult , Trace , TransformedNutsSettings , WalnutsOptions ,
2121} ;
2222use pyo3:: {
2323 exceptions:: PyTimeoutError ,
@@ -654,6 +654,134 @@ impl PyNutsSettings {
654654 }
655655 Ok ( ( ) )
656656 }
657+
658+ #[ getter]
659+ fn max_step_size_halvings ( & self ) -> Result < Option < u64 > > {
660+ let walnuts = match & self . inner {
661+ Settings :: LowRank ( inner) => inner. walnuts_options ,
662+ Settings :: Diag ( inner) => inner. walnuts_options ,
663+ Settings :: Transforming ( inner) => inner. walnuts_options ,
664+ } ;
665+ if let Some ( walnuts) = walnuts {
666+ Ok ( Some ( walnuts. max_step_size_halvings ) )
667+ } else {
668+ Ok ( None )
669+ }
670+ }
671+
672+ #[ setter( max_step_size_halvings) ]
673+ fn set_max_step_size_halvings ( & mut self , val : Option < u64 > ) -> Result < ( ) > {
674+ let options = match & mut self . inner {
675+ Settings :: LowRank ( inner) => & mut inner. walnuts_options ,
676+ Settings :: Diag ( inner) => & mut inner. walnuts_options ,
677+ Settings :: Transforming ( inner) => & mut inner. walnuts_options ,
678+ } ;
679+
680+ if let Some ( max_halvings) = val {
681+ if let Some ( ref mut options) = options {
682+ options. max_step_size_halvings = max_halvings;
683+ } else {
684+ let mut new_options = WalnutsOptions :: default ( ) ;
685+ new_options. max_step_size_halvings = max_halvings;
686+ * options = Some ( new_options) ;
687+ }
688+ } else {
689+ * options = None ;
690+ }
691+
692+ Ok ( ( ) )
693+ }
694+
695+ #[ getter]
696+ fn max_walnuts_energy_error ( & self ) -> Result < Option < f64 > > {
697+ let walnuts = match & self . inner {
698+ Settings :: LowRank ( inner) => inner. walnuts_options ,
699+ Settings :: Diag ( inner) => inner. walnuts_options ,
700+ Settings :: Transforming ( inner) => inner. walnuts_options ,
701+ } ;
702+ if let Some ( walnuts) = walnuts {
703+ Ok ( Some ( walnuts. max_energy_error ) )
704+ } else {
705+ Ok ( None )
706+ }
707+ }
708+
709+ #[ setter( max_walnuts_energy_error) ]
710+ fn set_max_walnuts_energy_error ( & mut self , val : Option < f64 > ) -> Result < ( ) > {
711+ let options = match & mut self . inner {
712+ Settings :: LowRank ( inner) => & mut inner. walnuts_options ,
713+ Settings :: Diag ( inner) => & mut inner. walnuts_options ,
714+ Settings :: Transforming ( inner) => & mut inner. walnuts_options ,
715+ } ;
716+
717+ if let Some ( max_error) = val {
718+ if let Some ( ref mut options) = options {
719+ options. max_energy_error = max_error;
720+ } else {
721+ let mut new_options = WalnutsOptions :: default ( ) ;
722+ new_options. max_energy_error = max_error;
723+ * options = Some ( new_options) ;
724+ }
725+ } else {
726+ * options = None ;
727+ }
728+
729+ Ok ( ( ) )
730+ }
731+
732+ #[ getter]
733+ fn fixed_step_size ( & self ) -> Result < Option < f64 > > {
734+ match & self . inner {
735+ Settings :: LowRank ( inner) => {
736+ Ok ( inner. adapt_options . dual_average_options . fixed_step_size )
737+ }
738+ Settings :: Diag ( inner) => Ok ( inner. adapt_options . dual_average_options . fixed_step_size ) ,
739+ Settings :: Transforming ( inner) => {
740+ Ok ( inner. adapt_options . dual_average_options . fixed_step_size )
741+ }
742+ }
743+ }
744+
745+ #[ setter( fixed_step_size) ]
746+ fn set_fixed_step_size ( & mut self , val : Option < f64 > ) -> Result < ( ) > {
747+ match & mut self . inner {
748+ Settings :: LowRank ( inner) => {
749+ inner. adapt_options . dual_average_options . fixed_step_size = val;
750+ }
751+ Settings :: Diag ( inner) => {
752+ inner. adapt_options . dual_average_options . fixed_step_size = val;
753+ }
754+ Settings :: Transforming ( inner) => {
755+ inner. adapt_options . dual_average_options . fixed_step_size = val;
756+ }
757+ }
758+ Ok ( ( ) )
759+ }
760+
761+ #[ getter]
762+ fn step_size_jitter ( & self ) -> Result < Option < f64 > > {
763+ match & self . inner {
764+ Settings :: LowRank ( inner) => Ok ( inner. adapt_options . dual_average_options . jitter ) ,
765+ Settings :: Diag ( inner) => Ok ( inner. adapt_options . dual_average_options . jitter ) ,
766+ Settings :: Transforming ( inner) => Ok ( inner. adapt_options . dual_average_options . jitter ) ,
767+ }
768+ }
769+
770+ #[ setter( step_size_jitter) ]
771+ fn set_step_size_jitter ( & mut self , val : Option < f64 > ) -> Result < ( ) > {
772+ match & mut self . inner {
773+ Settings :: LowRank ( inner) => {
774+ inner. adapt_options . dual_average_options . jitter = val;
775+ }
776+ Settings :: Diag ( inner) => {
777+ inner. adapt_options . dual_average_options . jitter = val;
778+ }
779+ Settings :: Transforming ( inner) => {
780+ inner. adapt_options . dual_average_options . jitter = val;
781+ }
782+ }
783+ Ok ( ( ) )
784+ }
657785}
658786
659787pub ( crate ) enum SamplerState {
0 commit comments