diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..54edd3c6bf 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -13,6 +13,8 @@ # limitations under the License. """PyMC-ArviZ conversion code.""" +from __future__ import annotations + import logging import warnings @@ -20,8 +22,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, + TypeAlias, cast, ) @@ -38,13 +39,16 @@ import pymc -from pymc.model import Model, modelcontext -from pymc.progress_bar import CustomProgress, default_progress_theme -from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import get_default_varnames +from pymc.model import modelcontext +from pymc.typing import StrongCoords if TYPE_CHECKING: from pymc.backends.base import MultiTrace + from pymc.model import Model + +from pymc.progress_bar import CustomProgress, default_progress_theme +from pymc.pytensorf import PointFunc, extract_obs_data +from pymc.util import get_default_varnames ___all__ = [""] @@ -56,6 +60,7 @@ # random variable object ... Var = Any +DimsDict: TypeAlias = Mapping[str, Sequence[str]] def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs): @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs) -def find_observations(model: "Model") -> dict[str, Var]: +def find_observations(model: Model) -> dict[str, Var]: """If there are observations available, return them as a dictionary.""" observations = {} for obs in model.observed_RVs: @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]: return observations -def find_constants(model: "Model") -> dict[str, Var]: +def find_constants(model: Model) -> dict[str, Var]: """If there are constants available, return them as a dictionary.""" model_vars = model.basic_RVs + model.deterministics + model.potentials value_vars = set(model.rvs_to_values.values()) @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]: return constant_data -def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]: +def coords_and_dims_for_inferencedata( + model: Model, +) -> tuple[StrongCoords, DimsDict]: """Parse PyMC model coords and dims format to one accepted by InferenceData.""" coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals @@ -265,7 +272,7 @@ def __init__( self.observations = find_observations(self.model) - def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: + def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]: """Split MultiTrace object into posterior and warmup. Returns @@ -491,7 +498,7 @@ def to_inference_data(self): def to_inference_data( - trace: Optional["MultiTrace"] = None, + trace: MultiTrace | None = None, *, prior: Mapping[str, Any] | None = None, posterior_predictive: Mapping[str, Any] | None = None, @@ -500,7 +507,7 @@ def to_inference_data( coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, - model: Optional["Model"] = None, + model: Model | None = None, save_warmup: bool | None = None, include_transformed: bool = False, ) -> InferenceData: @@ -568,8 +575,8 @@ def to_inference_data( ### perhaps we should have an inplace argument? def predictions_to_inference_data( predictions, - posterior_trace: Optional["MultiTrace"] = None, - model: Optional["Model"] = None, + posterior_trace: MultiTrace | None = None, + model: Model | None = None, coords: CoordSpec | None = None, dims: DimSpec | None = None, sample_dims: list | None = None, diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index cdab3046b1..98355267b1 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -19,7 +19,7 @@ from collections.abc import Sequence from functools import singledispatch from types import EllipsisType -from typing import Any, TypeAlias, cast +from typing import TYPE_CHECKING, Any, TypeAlias, cast import numpy as np @@ -33,9 +33,12 @@ from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -from pymc.model import modelcontext from pymc.pytensorf import convert_observed_data +if TYPE_CHECKING: + from pymc.model import Model + + __all__ = [ "change_dist_size", "rv_size_is_none", @@ -164,7 +167,7 @@ def convert_size(size: Size) -> StrongSize | None: ) -def shape_from_dims(dims: StrongDims, model) -> StrongShape: +def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape: """Determine shape from a `dims` tuple. Parameters @@ -176,9 +179,12 @@ def shape_from_dims(dims: StrongDims, model) -> StrongShape: Returns ------- - dims : tuple of (str or None) - Names or None for all RV dimensions. + shape : tuple + Shape inferred from model dimension lengths. """ + if model is None: + raise ValueError("model must be provided explicitly to infer shape from dims") + # Dims must be known already unknowndim_dims = set(dims) - set(model.dim_lengths) if unknowndim_dims: @@ -403,6 +409,8 @@ def get_support_shape( assert isinstance(dims, tuple) if len(dims) < ndim_supp: raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}") + from pymc.model.core import modelcontext + model = modelcontext(None) inferred_support_shape = [ model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0) diff --git a/pymc/model/core.py b/pymc/model/core.py index 3630138e00..fbc0ca0fb6 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -65,6 +65,7 @@ join_nonshared_inputs, rewrite_pregrad, ) +from pymc.typing import Coords, CoordValue, StrongCoords from pymc.util import ( UNSET, WithMemoization, @@ -453,7 +454,7 @@ def _validate_name(name): def __init__( self, name="", - coords=None, + coords: Coords | None = None, check_bounds=True, *, model: _UnsetType | None | Model = UNSET, @@ -488,7 +489,7 @@ def __init__( self.deterministics = treelist() self.potentials = treelist() self.data_vars = treelist() - self._coords = {} + self._coords: StrongCoords = {} self._dim_lengths = {} self.add_coords(coords) @@ -907,7 +908,7 @@ def unobserved_RVs(self): return self.free_RVs + self.deterministics @property - def coords(self) -> dict[str, tuple | None]: + def coords(self) -> StrongCoords: """Coordinate values for model dimensions.""" return self._coords @@ -937,7 +938,7 @@ def shape_from_dims(self, dims): def add_coord( self, name: str, - values: Sequence | np.ndarray | None = None, + values: CoordValue = None, *, length: int | Variable | None = None, ): diff --git a/pymc/printing.py b/pymc/printing.py index 63514ac4d0..180038e650 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,9 +13,12 @@ # limitations under the License. +from __future__ import annotations + import re from functools import partial +from typing import TYPE_CHECKING from pytensor.compile import SharedVariable from pytensor.graph.basic import Constant @@ -26,7 +29,9 @@ from pytensor.tensor.random.type import RandomType from pytensor.tensor.type_other import NoneTypeT -from pymc.model import Model +if TYPE_CHECKING: + from pymc.model import Model + __all__ = [ "str_for_dist", @@ -302,6 +307,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): # register our custom pretty printer in ipython shells import IPython + from pymc.model.core import Model + IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) IPython.lib.pretty.for_type(Model, _default_repr_pretty) except (ModuleNotFoundError, AttributeError): diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 98e177aa03..6fe97de21a 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -30,7 +30,7 @@ class DataClassState: def equal_dataclass_values(v1, v2): if v1.__class__ != v2.__class__: return False - if isinstance(v1, (list, tuple)): # noqa: UP038 + if isinstance(v1, list | tuple): return len(v1) == len(v2) and all( equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True) ) diff --git a/pymc/typing.py b/pymc/typing.py new file mode 100644 index 0000000000..577675b964 --- /dev/null +++ b/pymc/typing.py @@ -0,0 +1,33 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence +from typing import TypeAlias + +import numpy as np + +# ------------------------- +# Coordinate typing helpers +# ------------------------- + +# User-facing coordinate values (before normalization) +CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None +Coords: TypeAlias = Mapping[str, CoordValue] + +# After normalization / internal representation +StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None +StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]