Skip to content

Commit 7ccf981

Browse files
committed
Add Coords and StrongCoords typing aliases and standardize model/arviz usage
1 parent 87f80f9 commit 7ccf981

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

pymc/backends/arviz.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@
2323
Optional,
2424
Union,
2525
cast,
26+
TypeAlias,
2627
)
2728

29+
from pymc.distributions.shape_utils import StrongCoords
30+
31+
2832
import numpy as np
2933
import xarray
3034

@@ -56,6 +60,7 @@
5660

5761
# random variable object ...
5862
Var = Any
63+
DimsDict: TypeAlias = Mapping[str, Sequence[str]]
5964

6065

6166
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
@@ -123,7 +128,8 @@ def find_constants(model: "Model") -> dict[str, Var]:
123128
return constant_data
124129

125130

126-
def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
131+
def coords_and_dims_for_inferencedata(model: Model,) -> tuple[StrongCoords, DimsDict]:
132+
127133
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
128134
coords = {
129135
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals

pymc/distributions/shape_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ def _check_shape_type(shape):
9797
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
9898
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
9999

100+
from collections.abc import Mapping
101+
from typing import Hashable
102+
103+
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
104+
Coords: TypeAlias = Mapping[str, CoordValue]
105+
106+
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
107+
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]
108+
100109

101110
def convert_dims(dims: Dims | None) -> StrongDims | None:
102111
"""Process a user-provided dims variable into None or a valid dims tuple."""

pymc/model/core.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import warnings
2121

2222
from collections.abc import Iterable, Sequence
23+
from pymc.distributions.shape_utils import Coords, StrongCoords, CoordValue
24+
2325
from typing import (
2426
Literal,
2527
cast,
@@ -453,7 +455,7 @@ def _validate_name(name):
453455
def __init__(
454456
self,
455457
name="",
456-
coords=None,
458+
coords: Coords | None = None,
457459
check_bounds=True,
458460
*,
459461
model: _UnsetType | None | Model = UNSET,
@@ -488,7 +490,7 @@ def __init__(
488490
self.deterministics = treelist()
489491
self.potentials = treelist()
490492
self.data_vars = treelist()
491-
self._coords = {}
493+
self._coords: StrongCoords = {}
492494
self._dim_lengths = {}
493495
self.add_coords(coords)
494496

@@ -907,9 +909,9 @@ def unobserved_RVs(self):
907909
return self.free_RVs + self.deterministics
908910

909911
@property
910-
def coords(self) -> dict[str, tuple | None]:
912+
def coords(self) -> StrongCoords:
911913
"""Coordinate values for model dimensions."""
912-
return self._coords
914+
return self._coords
913915

914916
@property
915917
def dim_lengths(self) -> dict[str, TensorVariable]:
@@ -937,7 +939,7 @@ def shape_from_dims(self, dims):
937939
def add_coord(
938940
self,
939941
name: str,
940-
values: Sequence | np.ndarray | None = None,
942+
values: CoordValue = None,
941943
*,
942944
length: int | Variable | None = None,
943945
):

0 commit comments

Comments
 (0)