|
14 | 14 |
|
15 | 15 | import functools |
16 | 16 | from absl.testing import absltest |
| 17 | +from absl.testing import parameterized |
17 | 18 | import jax |
18 | 19 | from jax._src import test_util as jtu |
19 | 20 | from jax._src.state.primitives import pin, unpin |
@@ -382,6 +383,28 @@ def body(x_ref, o1_ref, o2_ref, scratch_ref): |
382 | 383 | np.testing.assert_array_equal(result1, x) |
383 | 384 | np.testing.assert_array_equal(result2, x + 1) |
384 | 385 |
|
| 386 | + @parameterized.named_parameters( |
| 387 | + ("HBM", pltpu.HBM, 0), |
| 388 | + ("VMEM", pltpu.VMEM, 1), |
| 389 | + ("SMEM", pltpu.SMEM, 4), |
| 390 | + ("SEMAPHORE", pltpu.SEMAPHORE, 2), |
| 391 | + ) |
| 392 | + def test_kernel_with_output_memory_space(self, memory_space, color): |
| 393 | + if not jtu.is_device_tpu_at_least(5): |
| 394 | + self.skipTest("Only supported on TPU v5+") |
| 395 | + mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) |
| 396 | + def body(x_ref, o_ref): |
| 397 | + pltpu.sync_copy(x_ref, o_ref) |
| 398 | + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) |
| 399 | + text = pl.kernel( |
| 400 | + body, out_shape=memory_space(x.shape, x.dtype), mesh=mesh, |
| 401 | + ).lower(x).as_text() |
| 402 | + custom_call = [l for l in text.split("\n") if "@tpu_custom_call" in l] |
| 403 | + self.assertLen(custom_call, 1) |
| 404 | + custom_call = custom_call[0] |
| 405 | + self.assertRegex(custom_call, |
| 406 | + r".*output_memory_colors\\22: \[" + str(color) + r"\].*") |
| 407 | + |
385 | 408 |
|
386 | 409 | if __name__ == "__main__": |
387 | 410 | absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments