Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.
"""PyMC-ArviZ conversion code."""

from __future__ import annotations

import logging
import warnings

from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
TypeAlias,
cast,
)

Expand All @@ -38,13 +39,16 @@

import pymc

from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames
from pymc.model import modelcontext
from pymc.typing import StrongCoords

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
from pymc.model import Model

from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames

___all__ = [""]

Expand All @@ -56,6 +60,7 @@

# random variable object ...
Var = Any
DimsDict: TypeAlias = Mapping[str, Sequence[str]]


def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
Expand Down Expand Up @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)


def find_observations(model: "Model") -> dict[str, Var]:
def find_observations(model: Model) -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
Expand All @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
return observations


def find_constants(model: "Model") -> dict[str, Var]:
def find_constants(model: Model) -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
model_vars = model.basic_RVs + model.deterministics + model.potentials
value_vars = set(model.rvs_to_values.values())
Expand All @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
def coords_and_dims_for_inferencedata(
model: Model,
) -> tuple[StrongCoords, DimsDict]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
Expand Down Expand Up @@ -265,7 +272,7 @@ def __init__(

self.observations = find_observations(self.model)

def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]:
"""Split MultiTrace object into posterior and warmup.

Returns
Expand Down Expand Up @@ -491,7 +498,7 @@ def to_inference_data(self):


def to_inference_data(
trace: Optional["MultiTrace"] = None,
trace: MultiTrace | None = None,
*,
prior: Mapping[str, Any] | None = None,
posterior_predictive: Mapping[str, Any] | None = None,
Expand All @@ -500,7 +507,7 @@ def to_inference_data(
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model: Optional["Model"] = None,
model: Model | None = None,
save_warmup: bool | None = None,
include_transformed: bool = False,
) -> InferenceData:
Expand Down Expand Up @@ -568,8 +575,8 @@ def to_inference_data(
### perhaps we should have an inplace argument?
def predictions_to_inference_data(
predictions,
posterior_trace: Optional["MultiTrace"] = None,
model: Optional["Model"] = None,
posterior_trace: MultiTrace | None = None,
model: Model | None = None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
Expand Down
18 changes: 13 additions & 5 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Sequence
from functools import singledispatch
from types import EllipsisType
from typing import Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, TypeAlias, cast

import numpy as np

Expand All @@ -33,9 +33,12 @@
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable

from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data

if TYPE_CHECKING:
from pymc.model import Model


__all__ = [
"change_dist_size",
"rv_size_is_none",
Expand Down Expand Up @@ -164,7 +167,7 @@ def convert_size(size: Size) -> StrongSize | None:
)


def shape_from_dims(dims: StrongDims, model) -> StrongShape:
def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape:
"""Determine shape from a `dims` tuple.

Parameters
Expand All @@ -176,9 +179,12 @@ def shape_from_dims(dims: StrongDims, model) -> StrongShape:

Returns
-------
dims : tuple of (str or None)
Names or None for all RV dimensions.
shape : tuple
Shape inferred from model dimension lengths.
"""
if model is None:
raise ValueError("model must be provided explicitly to infer shape from dims")

# Dims must be known already
unknowndim_dims = set(dims) - set(model.dim_lengths)
if unknowndim_dims:
Expand Down Expand Up @@ -403,6 +409,8 @@ def get_support_shape(
assert isinstance(dims, tuple)
if len(dims) < ndim_supp:
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
from pymc.model.core import modelcontext

model = modelcontext(None)
inferred_support_shape = [
model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0)
Expand Down
9 changes: 5 additions & 4 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
join_nonshared_inputs,
rewrite_pregrad,
)
from pymc.typing import Coords, CoordValue, StrongCoords
from pymc.util import (
UNSET,
WithMemoization,
Expand Down Expand Up @@ -453,7 +454,7 @@ def _validate_name(name):
def __init__(
self,
name="",
coords=None,
coords: Coords | None = None,
check_bounds=True,
*,
model: _UnsetType | None | Model = UNSET,
Expand Down Expand Up @@ -488,7 +489,7 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()
self.data_vars = treelist()
self._coords = {}
self._coords: StrongCoords = {}
self._dim_lengths = {}
self.add_coords(coords)

Expand Down Expand Up @@ -907,7 +908,7 @@ def unobserved_RVs(self):
return self.free_RVs + self.deterministics

@property
def coords(self) -> dict[str, tuple | None]:
def coords(self) -> StrongCoords:
"""Coordinate values for model dimensions."""
return self._coords

Expand Down Expand Up @@ -937,7 +938,7 @@ def shape_from_dims(self, dims):
def add_coord(
self,
name: str,
values: Sequence | np.ndarray | None = None,
values: CoordValue = None,
*,
length: int | Variable | None = None,
):
Expand Down
9 changes: 8 additions & 1 deletion pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.


from __future__ import annotations

import re

from functools import partial
from typing import TYPE_CHECKING

from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant
Expand All @@ -26,7 +29,9 @@
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT

from pymc.model import Model
if TYPE_CHECKING:
from pymc.model import Model


__all__ = [
"str_for_dist",
Expand Down Expand Up @@ -302,6 +307,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
# register our custom pretty printer in ipython shells
import IPython

from pymc.model.core import Model

IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty)
IPython.lib.pretty.for_type(Model, _default_repr_pretty)
except (ModuleNotFoundError, AttributeError):
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassState:
def equal_dataclass_values(v1, v2):
if v1.__class__ != v2.__class__:
return False
if isinstance(v1, (list, tuple)): # noqa: UP038
if isinstance(v1, list | tuple):
return len(v1) == len(v2) and all(
equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
)
Expand Down
33 changes: 33 additions & 0 deletions pymc/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from collections.abc import Hashable, Mapping, Sequence
from typing import TypeAlias

import numpy as np

# -------------------------
# Coordinate typing helpers
# -------------------------

# User-facing coordinate values (before normalization)
CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
Coords: TypeAlias = Mapping[str, CoordValue]

# After normalization / internal representation
StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]