Skip to content

Commit d14a85a

Browse files
committed
Fix shape_from_dims typing and remove circular import
1 parent b8e5ce1 commit d14a85a

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc/distributions/shape_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections.abc import Sequence
2020
from functools import singledispatch
2121
from types import EllipsisType
22-
from typing import Any, TypeAlias, cast
22+
from typing import TYPE_CHECKING, Any, TypeAlias, cast
2323

2424
import numpy as np
2525

@@ -33,9 +33,12 @@
3333
from pytensor.tensor.type_other import NoneTypeT
3434
from pytensor.tensor.variable import TensorVariable
3535

36-
# from pymc.model import modelcontext
3736
from pymc.pytensorf import convert_observed_data
3837

38+
if TYPE_CHECKING:
39+
from pymc.model import Model
40+
41+
3942
__all__ = [
4043
"change_dist_size",
4144
"rv_size_is_none",
@@ -164,7 +167,7 @@ def convert_size(size: Size) -> StrongSize | None:
164167
)
165168

166169

167-
def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape:
170+
def shape_from_dims(dims: StrongDims, model: "Model") -> StrongShape:
168171
"""Determine shape from a `dims` tuple.
169172
170173
Parameters

0 commit comments

Comments
 (0)