Skip to content

Commit 9ec840f

Browse files
brianwa84Google-ML-Automation
authored andcommitted
[Pallas:SC] Adds a test that verifies pl.kernel outputs are placed in the proper memory spaces.
PiperOrigin-RevId: 842659708
1 parent 15ba1b7 commit 9ec840f

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/pallas/tpu_pallas_state_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import functools
1616
from absl.testing import absltest
17+
from absl.testing import parameterized
1718
import jax
1819
from jax._src import test_util as jtu
1920
from jax._src.state.primitives import pin, unpin
@@ -382,6 +383,28 @@ def body(x_ref, o1_ref, o2_ref, scratch_ref):
382383
np.testing.assert_array_equal(result1, x)
383384
np.testing.assert_array_equal(result2, x + 1)
384385

386+
@parameterized.named_parameters(
387+
("HBM", pltpu.HBM, 0),
388+
("VMEM", pltpu.VMEM, 1),
389+
("SMEM", pltpu.SMEM, 4),
390+
("SEMAPHORE", pltpu.SEMAPHORE, 2),
391+
)
392+
def test_kernel_with_output_memory_space(self, memory_space, color):
393+
if not jtu.is_device_tpu_at_least(5):
394+
self.skipTest("Only supported on TPU v5+")
395+
mesh = pltpu.create_tensorcore_mesh("x", num_cores=1)
396+
def body(x_ref, o_ref):
397+
pltpu.sync_copy(x_ref, o_ref)
398+
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
399+
text = pl.kernel(
400+
body, out_shape=memory_space(x.shape, x.dtype), mesh=mesh,
401+
).lower(x).as_text()
402+
custom_call = [l for l in text.split("\n") if "@tpu_custom_call" in l]
403+
self.assertLen(custom_call, 1)
404+
custom_call = custom_call[0]
405+
self.assertRegex(custom_call,
406+
r".*output_memory_colors\\22: \[" + str(color) + r"\].*")
407+
385408

386409
if __name__ == "__main__":
387410
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)