From 1080099f74ec4c1ec544122bad6dc4634852a5a4 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Sun, 30 Nov 2025 16:19:03 +0500 Subject: [PATCH 1/4] Fixed UserWarning when converting sample_stats to idata --- pymc/backends/arviz.py | 37 +++++++++++++++++++++++-------------- pymc/smc/sampling.py | 4 +++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..4b4ede3935 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -29,7 +29,8 @@ import xarray from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires +from arviz.data.base import CoordSpec, DimSpec, requires +from arviz_base import dict_to_dataset from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -305,14 +306,14 @@ def posterior_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, @@ -347,14 +348,14 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, @@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self): data = self.posterior_predictive dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) @requires(["predictions"]) @@ -376,7 +381,11 @@ def predictions_to_xarray(self): data = self.predictions dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) def priors_to_xarray(self): @@ -399,7 +408,7 @@ def priors_to_xarray(self): if var_names is None else dict_to_dataset_drop_incompatible_coords( {k: np.expand_dims(self.prior[k], 0) for k in var_names}, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, ) @@ -414,10 +423,10 @@ def observed_data_to_xarray(self): return None return dict_to_dataset( self.observations, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) @requires("model") @@ -429,10 +438,10 @@ def constant_data_to_xarray(self): xarray_dataset = dict_to_dataset( constant_data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) # provisional handling of scalars in constant @@ -707,9 +716,9 @@ def apply_function_over_dataset( return dict_to_dataset( out_trace, - library=pymc, + inference_library=pymc, dims=dims, coords=coords, - default_dims=list(sample_dims), + sample_dims=list(sample_dims), skip_event_dims=True, ) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 5afd398281..352224026e 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -267,7 +267,9 @@ def _save_sample_stats( sample_stats = dict_to_dataset( sample_stats_dict, attrs=sample_settings_dict, - library=pymc, + inference_library=pymc, + sample_dims=["chain"], + check_conventions=False, ) ikwargs: dict[str, Any] = {"model": model} From 5becf3358fbe16073cf14e766de3927a58833213 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Sun, 30 Nov 2025 17:37:35 +0500 Subject: [PATCH 2/4] Added arviz-base to requirements list --- requirements-dev.txt | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 22bcdaf9ea..4be336125a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ # See that file for comments about the need/usage of each dependency. arviz>=0.13.0 +arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle ipython>=7.16 diff --git a/requirements.txt b/requirements.txt index 8401b78a15..80e1d5a027 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ arviz>=0.13.0 +arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle numpy>=1.25.0 From 10b09fb5b5511e26f79b9d3513aaad4d4009753f Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Mon, 1 Dec 2025 02:19:18 +0500 Subject: [PATCH 3/4] Added arviz-base dependency in conda-envs/environment-dev.yml --- conda-envs/environment-dev.yml | 1 + requirements-dev.txt | 2 +- requirements.txt | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 231dfa05cf..230dfcb62f 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - blas - cachetools>=4.2.1 - cloudpickle diff --git a/requirements-dev.txt b/requirements-dev.txt index 4be336125a..25bc83223e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. +arviz-base arviz>=0.13.0 -arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle ipython>=7.16 diff --git a/requirements.txt b/requirements.txt index 80e1d5a027..8401b78a15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ arviz>=0.13.0 -arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle numpy>=1.25.0 From b77c695393d80af19cb9b609e165112ed36cc6bc Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Mon, 1 Dec 2025 02:55:50 +0500 Subject: [PATCH 4/4] Attempting to add arviz-base as a dependency --- conda-envs/environment-docs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index f85f8fc55b..e8ea8cbd1a 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0