2929import xarray
3030
3131from arviz import InferenceData , concat , rcParams
32- from arviz .data .base import CoordSpec , DimSpec , dict_to_dataset , requires
32+ from arviz .data .base import CoordSpec , DimSpec , requires
33+ from arviz_base import dict_to_dataset
3334from pytensor .graph import ancestors
3435from pytensor .tensor .sharedvar import SharedVariable
3536from rich .progress import Console
@@ -305,14 +306,14 @@ def posterior_to_xarray(self):
305306 return (
306307 dict_to_dataset (
307308 data ,
308- library = pymc ,
309+ inference_library = pymc ,
309310 coords = self .coords ,
310311 dims = self .dims ,
311312 attrs = self .attrs ,
312313 ),
313314 dict_to_dataset (
314315 data_warmup ,
315- library = pymc ,
316+ inference_library = pymc ,
316317 coords = self .coords ,
317318 dims = self .dims ,
318319 attrs = self .attrs ,
@@ -347,14 +348,14 @@ def sample_stats_to_xarray(self):
347348 return (
348349 dict_to_dataset (
349350 data ,
350- library = pymc ,
351+ inference_library = pymc ,
351352 dims = None ,
352353 coords = self .coords ,
353354 attrs = self .attrs ,
354355 ),
355356 dict_to_dataset (
356357 data_warmup ,
357- library = pymc ,
358+ inference_library = pymc ,
358359 dims = None ,
359360 coords = self .coords ,
360361 attrs = self .attrs ,
@@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self):
367368 data = self .posterior_predictive
368369 dims = {var_name : self .sample_dims + self .dims .get (var_name , []) for var_name in data }
369370 return dict_to_dataset (
370- data , library = pymc , coords = self .coords , dims = dims , default_dims = self .sample_dims
371+ data ,
372+ inference_library = pymc ,
373+ coords = self .coords ,
374+ dims = dims ,
375+ sample_dims = self .sample_dims ,
371376 )
372377
373378 @requires (["predictions" ])
@@ -376,7 +381,11 @@ def predictions_to_xarray(self):
376381 data = self .predictions
377382 dims = {var_name : self .sample_dims + self .dims .get (var_name , []) for var_name in data }
378383 return dict_to_dataset (
379- data , library = pymc , coords = self .coords , dims = dims , default_dims = self .sample_dims
384+ data ,
385+ inference_library = pymc ,
386+ coords = self .coords ,
387+ dims = dims ,
388+ sample_dims = self .sample_dims ,
380389 )
381390
382391 def priors_to_xarray (self ):
@@ -399,7 +408,7 @@ def priors_to_xarray(self):
399408 if var_names is None
400409 else dict_to_dataset_drop_incompatible_coords (
401410 {k : np .expand_dims (self .prior [k ], 0 ) for k in var_names },
402- library = pymc ,
411+ inference_library = pymc ,
403412 coords = self .coords ,
404413 dims = self .dims ,
405414 )
@@ -414,10 +423,10 @@ def observed_data_to_xarray(self):
414423 return None
415424 return dict_to_dataset (
416425 self .observations ,
417- library = pymc ,
426+ inference_library = pymc ,
418427 coords = self .coords ,
419428 dims = self .dims ,
420- default_dims = [],
429+ sample_dims = [],
421430 )
422431
423432 @requires ("model" )
@@ -429,10 +438,10 @@ def constant_data_to_xarray(self):
429438
430439 xarray_dataset = dict_to_dataset (
431440 constant_data ,
432- library = pymc ,
441+ inference_library = pymc ,
433442 coords = self .coords ,
434443 dims = self .dims ,
435- default_dims = [],
444+ sample_dims = [],
436445 )
437446
438447 # provisional handling of scalars in constant
@@ -707,9 +716,9 @@ def apply_function_over_dataset(
707716
708717 return dict_to_dataset (
709718 out_trace ,
710- library = pymc ,
719+ inference_library = pymc ,
711720 dims = dims ,
712721 coords = coords ,
713- default_dims = list (sample_dims ),
722+ sample_dims = list (sample_dims ),
714723 skip_event_dims = True ,
715724 )
0 commit comments