Skip to content

Commit cc0a20f

Browse files
superbobryjax authors
authored andcommitted
Raise a lowering-time error when broadcasted operand has invalid shape
Previously, we let these invalid broadcasts through, which led to crashes in Triton compiler passes, because Triton does not have a verifier checking that a tt.broadcast op is valid. PiperOrigin-RevId: 638277527
1 parent 0a5b2e1 commit cc0a20f

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value:
136136
a_type = ir.RankedTensorType(a.type)
137137
if a_type.shape == [*shape]:
138138
return a
139+
if a_type.rank != len(shape) or not all(
140+
a_type.shape[i] in (dim, 1) for i, dim in enumerate(shape)
141+
):
142+
raise ValueError(f"Cannot broadcast from {a_type.shape} to {[*shape]}")
139143
return tt_dialect.broadcast(
140144
ir.RankedTensorType.get(shape, a_type.element_type, a_type.encoding), a
141145
)

tests/pallas/pallas_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,33 @@ def load(x_ref, o_ref):
452452
x = random.normal(key, (m, n))
453453
np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
454454

455+
@parameterized.parameters(
456+
((16, 32), (16,)),
457+
((16, 32), (32,)),
458+
((16, 32), (16, 31)),
459+
)
460+
def test_invalid_broadcasted_load(self, x_shape, mask_shape):
461+
if self.INTERPRET:
462+
self.skipTest("No broadcasting checks in pl.load in interepreter mode")
463+
464+
@functools.partial(
465+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)
466+
)
467+
def kernel(x_ref, mask_ref, o_ref):
468+
del o_ref # Unused.
469+
pl.load(x_ref, slice(None), mask=mask_ref[:])
470+
471+
x = jnp.ones(x_shape, dtype=jnp.float32)
472+
mask = jnp.ones(mask_shape, dtype=jnp.bool_)
473+
# assertRaises* methods do not support inspecting the __cause__, so
474+
# we have to check it manually.
475+
try:
476+
kernel(x, mask)
477+
except Exception as e:
478+
self.assertIn("Cannot broadcast", str(e.__cause__))
479+
else:
480+
self.fail("Expected exception due to invalid broadcasting")
481+
455482
def test_swap(self):
456483
m, n = 16, 32
457484

0 commit comments

Comments
 (0)