1313# limitations under the License.
1414"""PyMC-ArviZ conversion code."""
1515
16+ from __future__ import annotations
17+
1618import logging
1719import warnings
1820
1921from collections .abc import Iterable , Mapping , Sequence
2022from typing import (
2123 TYPE_CHECKING ,
2224 Any ,
23- Optional ,
2425 TypeAlias ,
25- Union ,
2626 cast ,
2727)
2828
@@ -90,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
9090 return dict_to_dataset (vars_dict , * args , dims = dims , coords = safe_coords , ** kwargs )
9191
9292
93- def find_observations (model : " Model" ) -> dict [str , Var ]:
93+ def find_observations (model : Model ) -> dict [str , Var ]:
9494 """If there are observations available, return them as a dictionary."""
9595 observations = {}
9696 for obs in model .observed_RVs :
@@ -107,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
107107 return observations
108108
109109
110- def find_constants (model : " Model" ) -> dict [str , Var ]:
110+ def find_constants (model : Model ) -> dict [str , Var ]:
111111 """If there are constants available, return them as a dictionary."""
112112 model_vars = model .basic_RVs + model .deterministics + model .potentials
113113 value_vars = set (model .rvs_to_values .values ())
@@ -272,7 +272,7 @@ def __init__(
272272
273273 self .observations = find_observations (self .model )
274274
275- def split_trace (self ) -> tuple [Union [ None , " MultiTrace" ], Union [ None , " MultiTrace" ] ]:
275+ def split_trace (self ) -> tuple [None | MultiTrace , None | MultiTrace ]:
276276 """Split MultiTrace object into posterior and warmup.
277277
278278 Returns
@@ -498,7 +498,7 @@ def to_inference_data(self):
498498
499499
500500def to_inference_data (
501- trace : Optional [ " MultiTrace" ] = None ,
501+ trace : MultiTrace | None = None ,
502502 * ,
503503 prior : Mapping [str , Any ] | None = None ,
504504 posterior_predictive : Mapping [str , Any ] | None = None ,
@@ -507,7 +507,7 @@ def to_inference_data(
507507 coords : CoordSpec | None = None ,
508508 dims : DimSpec | None = None ,
509509 sample_dims : list | None = None ,
510- model : Optional [ " Model" ] = None ,
510+ model : Model | None = None ,
511511 save_warmup : bool | None = None ,
512512 include_transformed : bool = False ,
513513) -> InferenceData :
@@ -575,8 +575,8 @@ def to_inference_data(
575575### perhaps we should have an inplace argument?
576576def predictions_to_inference_data (
577577 predictions ,
578- posterior_trace : Optional [ " MultiTrace" ] = None ,
579- model : Optional [ " Model" ] = None ,
578+ posterior_trace : MultiTrace | None = None ,
579+ model : Model | None = None ,
580580 coords : CoordSpec | None = None ,
581581 dims : DimSpec | None = None ,
582582 sample_dims : list | None = None ,
0 commit comments