Skip to content

Commit 1080099

Browse files
committed
Fixed UserWarning when converting sample_stats to idata
1 parent f69e159 commit 1080099

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

pymc/backends/arviz.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import xarray
3030

3131
from arviz import InferenceData, concat, rcParams
32-
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
32+
from arviz.data.base import CoordSpec, DimSpec, requires
33+
from arviz_base import dict_to_dataset
3334
from pytensor.graph import ancestors
3435
from pytensor.tensor.sharedvar import SharedVariable
3536
from rich.progress import Console
@@ -305,14 +306,14 @@ def posterior_to_xarray(self):
305306
return (
306307
dict_to_dataset(
307308
data,
308-
library=pymc,
309+
inference_library=pymc,
309310
coords=self.coords,
310311
dims=self.dims,
311312
attrs=self.attrs,
312313
),
313314
dict_to_dataset(
314315
data_warmup,
315-
library=pymc,
316+
inference_library=pymc,
316317
coords=self.coords,
317318
dims=self.dims,
318319
attrs=self.attrs,
@@ -347,14 +348,14 @@ def sample_stats_to_xarray(self):
347348
return (
348349
dict_to_dataset(
349350
data,
350-
library=pymc,
351+
inference_library=pymc,
351352
dims=None,
352353
coords=self.coords,
353354
attrs=self.attrs,
354355
),
355356
dict_to_dataset(
356357
data_warmup,
357-
library=pymc,
358+
inference_library=pymc,
358359
dims=None,
359360
coords=self.coords,
360361
attrs=self.attrs,
@@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self):
367368
data = self.posterior_predictive
368369
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
369370
return dict_to_dataset(
370-
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
371+
data,
372+
inference_library=pymc,
373+
coords=self.coords,
374+
dims=dims,
375+
sample_dims=self.sample_dims,
371376
)
372377

373378
@requires(["predictions"])
@@ -376,7 +381,11 @@ def predictions_to_xarray(self):
376381
data = self.predictions
377382
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
378383
return dict_to_dataset(
379-
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
384+
data,
385+
inference_library=pymc,
386+
coords=self.coords,
387+
dims=dims,
388+
sample_dims=self.sample_dims,
380389
)
381390

382391
def priors_to_xarray(self):
@@ -399,7 +408,7 @@ def priors_to_xarray(self):
399408
if var_names is None
400409
else dict_to_dataset_drop_incompatible_coords(
401410
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
402-
library=pymc,
411+
inference_library=pymc,
403412
coords=self.coords,
404413
dims=self.dims,
405414
)
@@ -414,10 +423,10 @@ def observed_data_to_xarray(self):
414423
return None
415424
return dict_to_dataset(
416425
self.observations,
417-
library=pymc,
426+
inference_library=pymc,
418427
coords=self.coords,
419428
dims=self.dims,
420-
default_dims=[],
429+
sample_dims=[],
421430
)
422431

423432
@requires("model")
@@ -429,10 +438,10 @@ def constant_data_to_xarray(self):
429438

430439
xarray_dataset = dict_to_dataset(
431440
constant_data,
432-
library=pymc,
441+
inference_library=pymc,
433442
coords=self.coords,
434443
dims=self.dims,
435-
default_dims=[],
444+
sample_dims=[],
436445
)
437446

438447
# provisional handling of scalars in constant
@@ -707,9 +716,9 @@ def apply_function_over_dataset(
707716

708717
return dict_to_dataset(
709718
out_trace,
710-
library=pymc,
719+
inference_library=pymc,
711720
dims=dims,
712721
coords=coords,
713-
default_dims=list(sample_dims),
722+
sample_dims=list(sample_dims),
714723
skip_event_dims=True,
715724
)

pymc/smc/sampling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ def _save_sample_stats(
267267
sample_stats = dict_to_dataset(
268268
sample_stats_dict,
269269
attrs=sample_settings_dict,
270-
library=pymc,
270+
inference_library=pymc,
271+
sample_dims=["chain"],
272+
check_conventions=False,
271273
)
272274

273275
ikwargs: dict[str, Any] = {"model": model}

0 commit comments

Comments
 (0)