3434
3535from arviz import InferenceData , dict_to_dataset
3636from arviz .data .base import make_attrs
37- from fastprogress .fastprogress import progress_bar
3837from pytensor .graph .basic import Variable
38+ from rich .console import Console
39+ from rich .progress import Progress
40+ from rich .theme import Theme
3941from typing_extensions import Protocol , TypeAlias
4042
4143import pymc as pm
6567 RandomSeed ,
6668 RandomState ,
6769 _get_seeds_per_chain ,
70+ default_progress_theme ,
6871 drop_warning_stat ,
6972 get_untransformed_name ,
7073 is_transformed_name ,
@@ -377,6 +380,7 @@ def sample(
377380 cores : Optional [int ] = None ,
378381 random_seed : RandomState = None ,
379382 progressbar : bool = True ,
383+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
380384 step = None ,
381385 var_names : Optional [Sequence [str ]] = None ,
382386 nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -406,6 +410,7 @@ def sample(
406410 cores : Optional [int ] = None ,
407411 random_seed : RandomState = None ,
408412 progressbar : bool = True ,
413+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
409414 step = None ,
410415 var_names : Optional [Sequence [str ]] = None ,
411416 nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -435,6 +440,7 @@ def sample(
435440 cores : Optional [int ] = None ,
436441 random_seed : RandomState = None ,
437442 progressbar : bool = True ,
443+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
438444 step = None ,
439445 var_names : Optional [Sequence [str ]] = None ,
440446 nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -761,6 +767,7 @@ def sample(
761767 "tune" : tune ,
762768 "var_names" : var_names ,
763769 "progressbar" : progressbar ,
770+ "progressbar_theme" : progressbar_theme ,
764771 "model" : model ,
765772 "cores" : cores ,
766773 "callback" : callback ,
@@ -983,6 +990,7 @@ def _sample(
983990 trace : IBaseTrace ,
984991 tune : int ,
985992 model : Optional [Model ] = None ,
993+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
986994 callback = None ,
987995 ** kwargs ,
988996) -> None :
@@ -1010,6 +1018,8 @@ def _sample(
10101018 tune : int
10111019 Number of iterations to tune.
10121020 model : Model (optional if in ``with`` context)
1021+ progressbar_theme : Theme
1022+ Optional custom theme for the progress bar.
10131023 """
10141024 skip_first = kwargs .get ("skip_first" , 0 )
10151025
@@ -1026,19 +1036,16 @@ def _sample(
10261036 )
10271037 _pbar_data = {"chain" : chain , "divergences" : 0 }
10281038 _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1029- if progressbar :
1030- sampling = progress_bar (sampling_gen , total = draws , display = progressbar )
1031- sampling .comment = _desc .format (** _pbar_data )
1032- else :
1033- sampling = sampling_gen
1034- try :
1035- for it , diverging in enumerate (sampling ):
1036- if it >= skip_first and diverging :
1037- _pbar_data ["divergences" ] += 1
1038- if progressbar :
1039- sampling .comment = _desc .format (** _pbar_data )
1040- except KeyboardInterrupt :
1041- pass
1039+ with Progress (console = Console (theme = progressbar_theme )) as progress :
1040+ try :
1041+ task = progress .add_task (_desc .format (** _pbar_data ), total = draws , visible = progressbar )
1042+ for it , diverging in enumerate (sampling_gen ):
1043+ if it >= skip_first and diverging :
1044+ _pbar_data ["divergences" ] += 1
1045+ progress .update (task , advance = 1 )
1046+ progress .update (task , advance = 1 , completed = True )
1047+ except KeyboardInterrupt :
1048+ pass
10421049
10431050
10441051def _iter_sample (
@@ -1131,6 +1138,7 @@ def _mp_sample(
11311138 random_seed : Sequence [RandomSeed ],
11321139 start : Sequence [PointType ],
11331140 progressbar : bool = True ,
1141+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
11341142 traces : Sequence [IBaseTrace ],
11351143 model : Optional [Model ] = None ,
11361144 callback : Optional [SamplingIteratorCallback ] = None ,
@@ -1158,6 +1166,8 @@ def _mp_sample(
11581166 Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
11591167 progressbar : bool
11601168 Whether or not to display a progress bar in the command line.
1169+ progressbar_theme : Theme
1170+ Optional custom theme for the progress bar.
11611171 traces
11621172 Recording backends for each chain.
11631173 model : Model (optional if in ``with`` context)
@@ -1182,6 +1192,7 @@ def _mp_sample(
11821192 start_points = start ,
11831193 step_method = step ,
11841194 progressbar = progressbar ,
1195+ progressbar_theme = progressbar_theme ,
11851196 mp_ctx = mp_ctx ,
11861197 )
11871198 try :
0 commit comments