diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index f3532c895a..87a62cad81 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2627,6 +2627,11 @@ def as_index_variable(idx): idx = as_tensor_variable(idx) if idx.type.dtype not in discrete_dtypes: raise TypeError("index must be integers or a boolean mask") + if idx.type.dtype == "bool" and idx.type.ndim == 0: + raise NotImplementedError( + "Boolean scalar indexing not implemented. " + "Open an issue in https://github.com/pymc-devs/pytensor/issues if you need this behavior." + ) return idx diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 261a8bbc4a..145f1077d3 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -598,7 +598,7 @@ def is_empty_array(val): def __setitem__(self, key, value): raise TypeError( - "TensorVariable does not support item assignment. Use the output of `set` or `add` instead." + "TensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." ) def take(self, indices, axis=None, mode="raise"): diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index aebd60de56..7b3f9af617 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -2228,6 +2228,11 @@ def fun(x, y): mode=self.mode, ) + def test_boolean_scalar_raises(self): + x = vector("x") + with pytest.raises(NotImplementedError): + x[np.array(True)] + class TestInferShape(utt.InferShapeTester): @staticmethod diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 50c36a05fc..57e47ce064 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -1,3 +1,4 @@ +import re from copy import copy import numpy as np @@ -444,7 +445,7 @@ def test_set_inc(self): def test_set_item_error(self): x = matrix("x") - msg = "Use the output of `set` or `add` instead." + msg = re.escape("Use the output of `x[idx].set` or `x[idx].inc` instead.") with pytest.raises(TypeError, match=msg): x[0] = 5 with pytest.raises(TypeError, match=msg):