Skip to content

Commit 569b99c

Browse files
committed
Move coords typing to pymc.typing and fix printing imports
1 parent f0a92c5 commit 569b99c

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pymc/backends/arviz.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414
"""PyMC-ArviZ conversion code."""
1515

16+
from __future__ import annotations
17+
1618
import logging
1719
import warnings
1820

1921
from collections.abc import Iterable, Mapping, Sequence
2022
from typing import (
2123
TYPE_CHECKING,
2224
Any,
23-
Optional,
2425
TypeAlias,
25-
Union,
2626
cast,
2727
)
2828

@@ -90,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
9090
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
9191

9292

93-
def find_observations(model: "Model") -> dict[str, Var]:
93+
def find_observations(model: Model) -> dict[str, Var]:
9494
"""If there are observations available, return them as a dictionary."""
9595
observations = {}
9696
for obs in model.observed_RVs:
@@ -107,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
107107
return observations
108108

109109

110-
def find_constants(model: "Model") -> dict[str, Var]:
110+
def find_constants(model: Model) -> dict[str, Var]:
111111
"""If there are constants available, return them as a dictionary."""
112112
model_vars = model.basic_RVs + model.deterministics + model.potentials
113113
value_vars = set(model.rvs_to_values.values())
@@ -272,7 +272,7 @@ def __init__(
272272

273273
self.observations = find_observations(self.model)
274274

275-
def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
275+
def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]:
276276
"""Split MultiTrace object into posterior and warmup.
277277
278278
Returns
@@ -498,7 +498,7 @@ def to_inference_data(self):
498498

499499

500500
def to_inference_data(
501-
trace: Optional["MultiTrace"] = None,
501+
trace: MultiTrace | None = None,
502502
*,
503503
prior: Mapping[str, Any] | None = None,
504504
posterior_predictive: Mapping[str, Any] | None = None,
@@ -507,7 +507,7 @@ def to_inference_data(
507507
coords: CoordSpec | None = None,
508508
dims: DimSpec | None = None,
509509
sample_dims: list | None = None,
510-
model: Optional["Model"] = None,
510+
model: Model | None = None,
511511
save_warmup: bool | None = None,
512512
include_transformed: bool = False,
513513
) -> InferenceData:
@@ -575,8 +575,8 @@ def to_inference_data(
575575
### perhaps we should have an inplace argument?
576576
def predictions_to_inference_data(
577577
predictions,
578-
posterior_trace: Optional["MultiTrace"] = None,
579-
model: Optional["Model"] = None,
578+
posterior_trace: MultiTrace | None = None,
579+
model: Model | None = None,
580580
coords: CoordSpec | None = None,
581581
dims: DimSpec | None = None,
582582
sample_dims: list | None = None,

0 commit comments

Comments
 (0)