From 7ccf9816f493be0a68f5087242ce673abb78dfce Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Tue, 2 Dec 2025 18:31:48 +0530 Subject: [PATCH 01/16] Add Coords and StrongCoords typing aliases and standardize model/arviz usage --- pymc/backends/arviz.py | 8 +++++++- pymc/distributions/shape_utils.py | 9 +++++++++ pymc/model/core.py | 12 +++++++----- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..9910271f76 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -23,8 +23,12 @@ Optional, Union, cast, + TypeAlias, ) +from pymc.distributions.shape_utils import StrongCoords + + import numpy as np import xarray @@ -56,6 +60,7 @@ # random variable object ... Var = Any +DimsDict: TypeAlias = Mapping[str, Sequence[str]] def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs): @@ -123,7 +128,8 @@ def find_constants(model: "Model") -> dict[str, Var]: return constant_data -def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]: +def coords_and_dims_for_inferencedata(model: Model,) -> tuple[StrongCoords, DimsDict]: + """Parse PyMC model coords and dims format to one accepted by InferenceData.""" coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index cdab3046b1..0123e3e664 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -97,6 +97,15 @@ def _check_shape_type(shape): StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] +from collections.abc import Mapping +from typing import Hashable + +CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None +Coords: TypeAlias = Mapping[str, CoordValue] + +StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None +StrongCoords: TypeAlias = Mapping[str, StrongCoordValue] + def convert_dims(dims: Dims | None) -> StrongDims | None: """Process a user-provided dims variable into None or a valid dims tuple.""" diff --git a/pymc/model/core.py b/pymc/model/core.py index 3630138e00..8a407e5ed2 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -20,6 +20,8 @@ import warnings from collections.abc import Iterable, Sequence +from pymc.distributions.shape_utils import Coords, StrongCoords, CoordValue + from typing import ( Literal, cast, @@ -453,7 +455,7 @@ def _validate_name(name): def __init__( self, name="", - coords=None, + coords: Coords | None = None, check_bounds=True, *, model: _UnsetType | None | Model = UNSET, @@ -488,7 +490,7 @@ def __init__( self.deterministics = treelist() self.potentials = treelist() self.data_vars = treelist() - self._coords = {} + self._coords: StrongCoords = {} self._dim_lengths = {} self.add_coords(coords) @@ -907,9 +909,9 @@ def unobserved_RVs(self): return self.free_RVs + self.deterministics @property - def coords(self) -> dict[str, tuple | None]: + def coords(self) -> StrongCoords: """Coordinate values for model dimensions.""" - return self._coords + return self._coords @property def dim_lengths(self) -> dict[str, TensorVariable]: @@ -937,7 +939,7 @@ def shape_from_dims(self, dims): def add_coord( self, name: str, - values: Sequence | np.ndarray | None = None, + values: CoordValue = None, *, length: int | Variable | None = None, ): From 77c2acbcdbd365144d2c39c65199a1910a9dc4be Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 22:37:22 +0530 Subject: [PATCH 02/16] Fix ruff formatting, imports, and dev requirements --- pymc/backends/arviz.py | 11 +++++------ pymc/distributions/shape_utils.py | 5 +---- pymc/model/core.py | 5 ++--- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 9910271f76..dba26582f3 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -21,14 +21,11 @@ TYPE_CHECKING, Any, Optional, + TypeAlias, Union, cast, - TypeAlias, ) -from pymc.distributions.shape_utils import StrongCoords - - import numpy as np import xarray @@ -42,6 +39,7 @@ import pymc +from pymc.distributions.shape_utils import StrongCoords from pymc.model import Model, modelcontext from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import PointFunc, extract_obs_data @@ -128,8 +126,9 @@ def find_constants(model: "Model") -> dict[str, Var]: return constant_data -def coords_and_dims_for_inferencedata(model: Model,) -> tuple[StrongCoords, DimsDict]: - +def coords_and_dims_for_inferencedata( + model: Model, +) -> tuple[StrongCoords, DimsDict]: """Parse PyMC model coords and dims format to one accepted by InferenceData.""" coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 0123e3e664..11c19de559 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -16,7 +16,7 @@ import warnings -from collections.abc import Sequence +from collections.abc import Hashable, Mapping, Sequence from functools import singledispatch from types import EllipsisType from typing import Any, TypeAlias, cast @@ -97,9 +97,6 @@ def _check_shape_type(shape): StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] -from collections.abc import Mapping -from typing import Hashable - CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None Coords: TypeAlias = Mapping[str, CoordValue] diff --git a/pymc/model/core.py b/pymc/model/core.py index 8a407e5ed2..7a58e52f66 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -20,8 +20,6 @@ import warnings from collections.abc import Iterable, Sequence -from pymc.distributions.shape_utils import Coords, StrongCoords, CoordValue - from typing import ( Literal, cast, @@ -46,6 +44,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import MinibatchOp, is_valid_observed +from pymc.distributions.shape_utils import Coords, CoordValue, StrongCoords from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, @@ -911,7 +910,7 @@ def unobserved_RVs(self): @property def coords(self) -> StrongCoords: """Coordinate values for model dimensions.""" - return self._coords + return self._coords @property def dim_lengths(self) -> dict[str, TensorVariable]: From b384bff63838585ae13b472e5738d7e1288a4bde Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 22:47:09 +0530 Subject: [PATCH 03/16] Fix circular import by importing modelcontext from pymc.model.core --- pymc/distributions/shape_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 11c19de559..e0832c3cbe 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -33,7 +33,7 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -from pymc.model import modelcontext +from pymc.model.core import modelcontext from pymc.pytensorf import convert_observed_data __all__ = [ From e7da246b808250fe48a6381cb6a492bca2912f20 Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 22:53:51 +0530 Subject: [PATCH 04/16] Fix circular import by lazily importing modelcontext in shape_from_dims --- pymc/distributions/shape_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index e0832c3cbe..cbd8935037 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -33,7 +33,7 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -from pymc.model.core import modelcontext +from pymc.model import modelcontext from pymc.pytensorf import convert_observed_data __all__ = [ @@ -170,21 +170,27 @@ def convert_size(size: Size) -> StrongSize | None: ) -def shape_from_dims(dims: StrongDims, model) -> StrongShape: +def shape_from_dims(dims: StrongDims, model=None) -> StrongShape: """Determine shape from a `dims` tuple. Parameters ---------- dims : array-like A vector of dimension names or None. - model : pm.Model - The current model on stack. + model : pm.Model, optional + The current model on stack. If None, it will be resolved via modelcontext. Returns ------- - dims : tuple of (str or None) - Names or None for all RV dimensions. + shape : tuple + Shape inferred from model dimension lengths. """ + # Lazy import to break circular dependency + if model is None: + from pymc.model.core import modelcontext + + model = modelcontext(None) + # Dims must be known already unknowndim_dims = set(dims) - set(model.dim_lengths) if unknowndim_dims: From 3079404fa55797da5ad9eb9dbb2464658fa7a81e Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 22:58:56 +0530 Subject: [PATCH 05/16] Fix circular import by using only lazy modelcontext imports --- pymc/distributions/shape_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index cbd8935037..7a61aece33 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -33,7 +33,7 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -from pymc.model import modelcontext +#from pymc.model import modelcontext from pymc.pytensorf import convert_observed_data __all__ = [ From 49c0db9f42fa53987b549fc0bc01307b53ebc70e Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 23:09:02 +0530 Subject: [PATCH 06/16] Fix Model circular import using TYPE_CHECKING and lazy import --- pymc/printing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/printing.py b/pymc/printing.py index 63514ac4d0..e769aeb831 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -14,8 +14,8 @@ import re - from functools import partial +from typing import TYPE_CHECKING from pytensor.compile import SharedVariable from pytensor.graph.basic import Constant @@ -26,7 +26,9 @@ from pytensor.tensor.random.type import RandomType from pytensor.tensor.type_other import NoneTypeT -from pymc.model import Model +if TYPE_CHECKING: + from pymc.model import Model + __all__ = [ "str_for_dist", @@ -301,6 +303,7 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): try: # register our custom pretty printer in ipython shells import IPython + from pymc.model.core import Model IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) IPython.lib.pretty.for_type(Model, _default_repr_pretty) From ec222e2c1742bdab0b2def18d1e480f4bb6d8c54 Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 23:12:33 +0530 Subject: [PATCH 07/16] Fix lazy modelcontext import flagged by ruff --- pymc/distributions/shape_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 7a61aece33..cdc51b47f2 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -33,7 +33,7 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -#from pymc.model import modelcontext +# from pymc.model import modelcontext from pymc.pytensorf import convert_observed_data __all__ = [ @@ -415,6 +415,8 @@ def get_support_shape( assert isinstance(dims, tuple) if len(dims) < ndim_supp: raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}") + from pymc.model.core import modelcontext + model = modelcontext(None) inferred_support_shape = [ model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0) From 9444946e79674f8baf44da4dca481050625ce867 Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 23:19:32 +0530 Subject: [PATCH 08/16] Fix missing modelcontext import flagged by ruff --- pymc/distributions/shape_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index cdc51b47f2..46132c23a7 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -416,7 +416,6 @@ def get_support_shape( if len(dims) < ndim_supp: raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}") from pymc.model.core import modelcontext - model = modelcontext(None) inferred_support_shape = [ model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0) From d3759b85a5e64721191eee6bd6877c7585a58343 Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Wed, 3 Dec 2025 23:25:39 +0530 Subject: [PATCH 09/16] Fix circular import of Model in printing.py --- pymc/printing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/printing.py b/pymc/printing.py index e769aeb831..8796dc8981 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -14,6 +14,7 @@ import re + from functools import partial from typing import TYPE_CHECKING @@ -303,6 +304,7 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): try: # register our custom pretty printer in ipython shells import IPython + from pymc.model.core import Model IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) From 202cb041665ab6b68b524d35f583e8d1abcd25bb Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 12:35:19 +0530 Subject: [PATCH 10/16] Move coords typing to pymc.typing and fix circular imports --- pymc/backends/arviz.py | 12 ++++++----- pymc/distributions/shape_utils.py | 9 ++------- pymc/step_methods/state.py | 2 +- pymc/typing.py | 33 +++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 13 deletions(-) create mode 100644 pymc/typing.py diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index dba26582f3..42d4e401e2 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -39,14 +39,16 @@ import pymc -from pymc.distributions.shape_utils import StrongCoords -from pymc.model import Model, modelcontext -from pymc.progress_bar import CustomProgress, default_progress_theme -from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import get_default_varnames +from pymc.model import modelcontext +from pymc.typing import StrongCoords if TYPE_CHECKING: from pymc.backends.base import MultiTrace + from pymc.model import Model + +from pymc.progress_bar import CustomProgress, default_progress_theme +from pymc.pytensorf import PointFunc, extract_obs_data +from pymc.util import get_default_varnames ___all__ = [""] diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 46132c23a7..91f49497ff 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -16,7 +16,7 @@ import warnings -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Sequence from functools import singledispatch from types import EllipsisType from typing import Any, TypeAlias, cast @@ -97,12 +97,6 @@ def _check_shape_type(shape): StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] -CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None -Coords: TypeAlias = Mapping[str, CoordValue] - -StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None -StrongCoords: TypeAlias = Mapping[str, StrongCoordValue] - def convert_dims(dims: Dims | None) -> StrongDims | None: """Process a user-provided dims variable into None or a valid dims tuple.""" @@ -416,6 +410,7 @@ def get_support_shape( if len(dims) < ndim_supp: raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}") from pymc.model.core import modelcontext + model = modelcontext(None) inferred_support_shape = [ model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0) diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 98e177aa03..3cd765624d 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -30,7 +30,7 @@ class DataClassState: def equal_dataclass_values(v1, v2): if v1.__class__ != v2.__class__: return False - if isinstance(v1, (list, tuple)): # noqa: UP038 + if isinstance(v1, (list, tuple)): return len(v1) == len(v2) and all( equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True) ) diff --git a/pymc/typing.py b/pymc/typing.py new file mode 100644 index 0000000000..577675b964 --- /dev/null +++ b/pymc/typing.py @@ -0,0 +1,33 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence +from typing import TypeAlias + +import numpy as np + +# ------------------------- +# Coordinate typing helpers +# ------------------------- + +# User-facing coordinate values (before normalization) +CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None +Coords: TypeAlias = Mapping[str, CoordValue] + +# After normalization / internal representation +StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None +StrongCoords: TypeAlias = Mapping[str, StrongCoordValue] From 5d0ecacd8669412ed8d172e3a8391978baf257ba Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 12:47:03 +0530 Subject: [PATCH 11/16] Fix ruff UP038 isinstance union style --- pymc/step_methods/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 3cd765624d..6fe97de21a 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -30,7 +30,7 @@ class DataClassState: def equal_dataclass_values(v1, v2): if v1.__class__ != v2.__class__: return False - if isinstance(v1, (list, tuple)): + if isinstance(v1, list | tuple): return len(v1) == len(v2) and all( equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True) ) From f0a92c5492cc3b349fd3d3731836451960be9daa Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 12:56:17 +0530 Subject: [PATCH 12/16] Fix printing Model NameError and move Coord typing to pymc.typing --- pymc/model/core.py | 2 +- pymc/printing.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 7a58e52f66..6c12b3460e 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -44,7 +44,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import MinibatchOp, is_valid_observed -from pymc.distributions.shape_utils import Coords, CoordValue, StrongCoords +from pymc.typing import Coords, CoordValue, StrongCoords from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, diff --git a/pymc/printing.py b/pymc/printing.py index 8796dc8981..180038e650 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import re from functools import partial From 569b99c6c46ed58bfeb4a08d0bbec7c3197bbf1f Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 13:09:09 +0530 Subject: [PATCH 13/16] Move coords typing to pymc.typing and fix printing imports --- pymc/backends/arviz.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 42d4e401e2..54edd3c6bf 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -13,6 +13,8 @@ # limitations under the License. """PyMC-ArviZ conversion code.""" +from __future__ import annotations + import logging import warnings @@ -20,9 +22,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeAlias, - Union, cast, ) @@ -90,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs) -def find_observations(model: "Model") -> dict[str, Var]: +def find_observations(model: Model) -> dict[str, Var]: """If there are observations available, return them as a dictionary.""" observations = {} for obs in model.observed_RVs: @@ -107,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]: return observations -def find_constants(model: "Model") -> dict[str, Var]: +def find_constants(model: Model) -> dict[str, Var]: """If there are constants available, return them as a dictionary.""" model_vars = model.basic_RVs + model.deterministics + model.potentials value_vars = set(model.rvs_to_values.values()) @@ -272,7 +272,7 @@ def __init__( self.observations = find_observations(self.model) - def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: + def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]: """Split MultiTrace object into posterior and warmup. Returns @@ -498,7 +498,7 @@ def to_inference_data(self): def to_inference_data( - trace: Optional["MultiTrace"] = None, + trace: MultiTrace | None = None, *, prior: Mapping[str, Any] | None = None, posterior_predictive: Mapping[str, Any] | None = None, @@ -507,7 +507,7 @@ def to_inference_data( coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, - model: Optional["Model"] = None, + model: Model | None = None, save_warmup: bool | None = None, include_transformed: bool = False, ) -> InferenceData: @@ -575,8 +575,8 @@ def to_inference_data( ### perhaps we should have an inplace argument? def predictions_to_inference_data( predictions, - posterior_trace: Optional["MultiTrace"] = None, - model: Optional["Model"] = None, + posterior_trace: MultiTrace | None = None, + model: Model | None = None, coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, From b8e5ce1e03d1c15c7be4f562aaba557f3f61de4b Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 13:15:09 +0530 Subject: [PATCH 14/16] Remove implicit modelcontext fallback from shape_from_dims --- pymc/distributions/shape_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 91f49497ff..1b74f06627 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -164,26 +164,23 @@ def convert_size(size: Size) -> StrongSize | None: ) -def shape_from_dims(dims: StrongDims, model=None) -> StrongShape: +def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape: """Determine shape from a `dims` tuple. Parameters ---------- dims : array-like A vector of dimension names or None. - model : pm.Model, optional - The current model on stack. If None, it will be resolved via modelcontext. + model : pm.Model + The current model on stack. Returns ------- shape : tuple Shape inferred from model dimension lengths. """ - # Lazy import to break circular dependency if model is None: - from pymc.model.core import modelcontext - - model = modelcontext(None) + raise ValueError("model must be provided explicitly to infer shape from dims") # Dims must be known already unknowndim_dims = set(dims) - set(model.dim_lengths) From d14a85a470ecdf378be3a566c7fbfeb4be0f2d3c Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 13:25:44 +0530 Subject: [PATCH 15/16] Fix shape_from_dims typing and remove circular import --- pymc/distributions/shape_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 1b74f06627..98355267b1 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -19,7 +19,7 @@ from collections.abc import Sequence from functools import singledispatch from types import EllipsisType -from typing import Any, TypeAlias, cast +from typing import TYPE_CHECKING, Any, TypeAlias, cast import numpy as np @@ -33,9 +33,12 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -# from pymc.model import modelcontext from pymc.pytensorf import convert_observed_data +if TYPE_CHECKING: + from pymc.model import Model + + __all__ = [ "change_dist_size", "rv_size_is_none", @@ -164,7 +167,7 @@ def convert_size(size: Size) -> StrongSize | None: ) -def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape: +def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape: """Determine shape from a `dims` tuple. Parameters From fb5980e57905173a4188df10f235f5cc5adaae69 Mon Sep 17 00:00:00 2001 From: Aman Srivastava Date: Thu, 4 Dec 2025 13:27:55 +0530 Subject: [PATCH 16/16] Fix typing and import order --- pymc/model/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 6c12b3460e..fbc0ca0fb6 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -44,7 +44,6 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import MinibatchOp, is_valid_observed -from pymc.typing import Coords, CoordValue, StrongCoords from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, @@ -66,6 +65,7 @@ join_nonshared_inputs, rewrite_pregrad, ) +from pymc.typing import Coords, CoordValue, StrongCoords from pymc.util import ( UNSET, WithMemoization,