|
19 | 19 | import pymc as pm |
20 | 20 |
|
21 | 21 | from pymc.logprob.basic import _logprob |
22 | | -from pymc.distributions.transforms import RVTransform |
| 22 | +from pymc.distributions.transforms import Transform |
23 | 23 | from pytensor import tensor as pt |
24 | 24 | from pytensor.graph.basic import Apply |
25 | 25 | from pytensor.graph.op import Op |
26 | 26 | from pytensor.raise_op import Assert |
27 | 27 | from pytensor.tensor.random.op import RandomVariable |
28 | | -from pymc.distributions.continuous import assert_negative_support |
| 28 | +# assert_negative_support was removed in PyMC v5, we use Assert directly |
29 | 29 | from pymc.distributions.dist_math import check_parameters, factln, logpow |
30 | 30 | from pymc.distributions.distribution import _moment |
31 | 31 | from pymc.distributions.shape_utils import rv_size_is_none |
@@ -69,7 +69,7 @@ def grad(self, inputs, gradients): |
69 | 69 | ballBackwardOp = BallBackwardOp() |
70 | 70 |
|
71 | 71 |
|
72 | | -class BallTransform(RVTransform): |
| 72 | +class BallTransform(Transform): |
73 | 73 | name = "ball" |
74 | 74 |
|
75 | 75 | def backward(self, value, *inputs): |
@@ -109,7 +109,7 @@ class HyperballUniformRV(RandomVariable): |
109 | 109 |
|
110 | 110 | def make_node(self, rng, size, dtype, dim, alpha): |
111 | 111 | alpha = pt.as_tensor_variable(alpha) |
112 | | - dim = pt.as_tensor_variable(pm.aesaraf.intX(dim)) |
| 112 | + dim = pt.as_tensor_variable(pm.pytensorf.intX(dim)) |
113 | 113 | if dim.ndim > 0: |
114 | 114 | raise ValueError("dim must be a scalar variable (ndim=0).") |
115 | 115 | msg = "HyperballUniform dim parameter must be greater than 1" |
@@ -144,24 +144,26 @@ class HyperballUniform(pm.distributions.Continuous): |
144 | 144 | @classmethod |
145 | 145 | def dist(cls, dim, alpha=1.0, no_assert: bool = False, **kwargs): |
146 | 146 | if not no_assert: |
147 | | - alpha = assert_negative_support(alpha, "alpha", "HyperballUniform") |
| 147 | + # Assert alpha > 0 (positive support) |
| 148 | + alpha = pt.as_tensor_variable(alpha) |
| 149 | + alpha = Assert("alpha must be positive")(alpha, pt.gt(alpha, 0)) |
148 | 150 | return super().dist([dim, alpha], **kwargs) |
149 | 151 |
|
    @staticmethod
    def rv_op_moment(rv, size, dim, alpha):
        """Define the moment (initial point) for the RV.

        Returns a length-``dim`` vector whose entries are all
        ``0.5 / sqrt(dim)``; the Euclidean norm of that vector is exactly
        0.5, so the initial point lies strictly inside the unit hyperball.
        When an explicit ``size`` is given, the vector is broadcast to the
        full batch shape ``size + (dim,)``.

        NOTE(review): this staticmethod replaced an ``@_moment.register``
        dispatch function, but nothing in this file registers it with
        PyMC's moment machinery — confirm the framework actually picks
        ``rv_op_moment`` up, otherwise the RV is left without a moment.
        """
        moment = pt.ones((dim,), dtype=rv.dtype) * 0.5 / pt.sqrt(dim)
        if not rv_size_is_none(size):
            # Prepend the explicit batch dims to the core (dim,) shape.
            moment_size = pt.concatenate([size, [dim]])
            moment = pt.full(moment_size, moment)
        return moment
| 160 | + |
150 | 161 |
|
@_default_transform.register(HyperballUniformRV)
def ball_transform(op, rv):
    """Register the shared module-level ``ballTransform`` instance as the
    default value transform for HyperballUniform RVs."""
    return ballTransform
154 | 165 |
|
155 | 166 |
|
156 | | -@_moment.register(HyperballUniformRV) |
157 | | -def moment(op, rv, rng, size, dtype, dim, alpha): |
158 | | - moment = pt.ones((dim,), dtype=dtype) * 0.5 / pt.sqrt(dim) |
159 | | - if not rv_size_is_none(size): |
160 | | - moment_size = pt.concatenate([size, [dim]]) |
161 | | - moment = pt.full(moment_size, moment) |
162 | | - return moment |
163 | | - |
164 | | - |
165 | 167 | @_logprob.register(HyperballUniformRV) |
166 | 168 | def logp(op, value_var_list, rng, size, dtype, dim, alpha, **kwargs): |
167 | 169 | value = value_var_list[0] |
|
0 commit comments