Skip to content

Commit e7da246

Browse files
committed
Fix circular import by lazily importing modelcontext in shape_from_dims
1 parent b384bff commit e7da246

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

pymc/distributions/shape_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pytensor.tensor.type_other import NoneTypeT
3434
from pytensor.tensor.variable import TensorVariable
3535

36-
from pymc.model.core import modelcontext
36+
from pymc.model import modelcontext
3737
from pymc.pytensorf import convert_observed_data
3838

3939
__all__ = [
@@ -170,21 +170,27 @@ def convert_size(size: Size) -> StrongSize | None:
170170
)
171171

172172

173-
def shape_from_dims(dims: StrongDims, model) -> StrongShape:
173+
def shape_from_dims(dims: StrongDims, model=None) -> StrongShape:
174174
"""Determine shape from a `dims` tuple.
175175
176176
Parameters
177177
----------
178178
dims : array-like
179179
A vector of dimension names or None.
180-
model : pm.Model
181-
The current model on stack.
180+
model : pm.Model, optional
181+
The current model on stack. If None, it will be resolved via modelcontext.
182182
183183
Returns
184184
-------
185-
dims : tuple of (str or None)
186-
Names or None for all RV dimensions.
185+
shape : tuple
186+
Shape inferred from model dimension lengths.
187187
"""
188+
# Lazy import to break circular dependency
189+
if model is None:
190+
from pymc.model.core import modelcontext
191+
192+
model = modelcontext(None)
193+
188194
# Dims must be known already
189195
unknowndim_dims = set(dims) - set(model.dim_lengths)
190196
if unknowndim_dims:

0 commit comments

Comments
 (0)