Skip to content

Commit 2a6de35

Browse files
sharadmv and Google-ML-Automation
authored and committed
[Pallas SC] Allow semaphores to be returned by SCS kernels
PiperOrigin-RevId: 841972235
1 parent 0339ad1 commit 2a6de35

File tree

4 files changed

+89
-8
lines changed

4 files changed

+89
-8
lines changed

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
6262

6363

6464
def _get_memory_space_from_aval(
65-
out_aval: jax_core.AbstractValue,
65+
out_aval: jax_core.AbstractValue, kernel_type: tpu_core.KernelType
6666
) -> tpu_custom_call.MemorySpace | None:
6767
if not isinstance(out_aval, jax_core.ShapedArray):
6868
raise ValueError("Memory spaces not defined for non-ShapedArrays")
@@ -84,20 +84,29 @@ def _get_memory_space_from_aval(
8484
case tpu_core.MemorySpace.SMEM:
8585
return tpu_custom_call.MemorySpace.SMEM
8686
case tpu_core.MemorySpace.SEMAPHORE:
87-
return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
87+
match kernel_type:
88+
case tpu_core.KernelType.SC_SCALAR_SUBCORE:
89+
return tpu_custom_call.MemorySpace.SC_SCALAR_SEMAPHORE_MEM
90+
case tpu_core.KernelType.TC:
91+
return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
92+
case _:
93+
raise ValueError(f"Invalid kernel type for semaphore: {kernel_type}")
8894
case tpu_core.MemorySpace.HOST:
8995
return tpu_custom_call.MemorySpace.HOST
9096
return None
9197

9298

9399
def _get_memory_spaces_from_avals(
    avals: Sequence[jax_core.AbstractValue], kernel_type: tpu_core.KernelType
) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None:
  """Maps each aval to its custom-call memory space.

  Args:
    avals: Abstract values for the kernel operands/results.
    kernel_type: The Pallas kernel type, used to pick the semaphore memory
      space (SC scalar subcore vs. TC) in `_get_memory_space_from_aval`.

  Returns:
    A tuple with one memory space (or None) per aval, or None when no aval
    carries an explicit memory space at all.
  """
  # Only compute per-aval spaces when at least one aval explicitly carries a
  # memory space; otherwise the caller gets None and uses defaults.
  if not any(
      isinstance(a, pallas_core.ShapedArrayWithMemorySpace) for a in avals
  ):
    return None
  return tuple(
      _get_memory_space_from_aval(a, kernel_type=kernel_type) for a in avals
  )
102111

103112

@@ -140,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
140149
mlir_ctx.load_all_available_dialects()
141150
tpu.register_dialect(mlir_ctx)
142151

143-
match mosaic_params.kernel_type:
152+
match (kernel_type := mosaic_params.kernel_type):
144153
case tpu_core.KernelType.TC:
145154
lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
146155
case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
@@ -191,7 +200,9 @@ def _maybe_cast_inputs(*args):
191200
# Dynamic grid bounds have to go at the front.
192201
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
193202
kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
194-
output_memory_spaces = _get_memory_spaces_from_avals(out_avals)
203+
output_memory_spaces = _get_memory_spaces_from_avals(
204+
out_avals, kernel_type=kernel_type
205+
)
195206
input_memory_spaces = None
196207
if any(
197208
isinstance(aval, pallas_core.ShapedArrayWithMemorySpace)
@@ -202,7 +213,9 @@ def _maybe_cast_inputs(*args):
202213
raise NotImplementedError(
203214
"Dynamic grid bounds are not supported when specifying memory spaces for inputs."
204215
)
205-
input_memory_spaces = _get_memory_spaces_from_avals(ctx.avals_in)
216+
input_memory_spaces = _get_memory_spaces_from_avals(
217+
ctx.avals_in, kernel_type=kernel_type
218+
)
206219
if cost_estimate is not None:
207220
mosaic_cost_estimate = cast(
208221
tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate)

jax/_src/pallas/mosaic/sc_lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ def body_func(*args: ir.Value):
330330
mosaic_grid_mapping.block_mappings,
331331
):
332332
d = {}
333-
if str(arg.type.memory_space) == "#tpu.memory_space<hbm>":
333+
if (
334+
str(arg.type.memory_space) == "#tpu.memory_space<hbm>"
335+
or str(arg.type.memory_space) == "#tpu.memory_space<semaphore_mem>"
336+
):
334337
d["sc.persistent"] = ir.UnitAttr.get()
335338
if isinstance(bm, sc_core.BlockMapping) and bm.indexed_by is not None:
336339
d["sc.indexed_by"] = mlir.i32_attr(bm.indexed_by)

jax/_src/tpu_custom_call.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class MemorySpace(enum.Enum):
101101
SEMAPHORE_MEM = enum.auto()
102102
SMEM = enum.auto()
103103
HOST = enum.auto()
104+
SC_SCALAR_SEMAPHORE_MEM = enum.auto()
104105

105106
@property
106107
def color(self) -> int:
@@ -110,6 +111,8 @@ def color(self) -> int:
110111
return 1
111112
elif self == MemorySpace.SEMAPHORE_MEM:
112113
return 2
114+
elif self == MemorySpace.SC_SCALAR_SEMAPHORE_MEM:
115+
return 8
113116
elif self == MemorySpace.SMEM:
114117
return 4
115118
elif self == MemorySpace.HOST:

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,5 +1923,67 @@ class PipelineTestWithTCTiling(TCTilingMixin, PipelineTest):
19231923
pass
19241924

19251925

1926+
class PallasSparsecoreAsyncTest(PallasSCTest):
  """Tests DMA semaphores returned as outputs of SC scalar-subcore kernels."""

  @parameterized.product(
      shape=[
          (8, 128),
          (8, 256),
          (8, 512),
          (8, 1024),
          (16, 128),
          (16, 256),
          (16, 512),
          (16, 1024),
          # TODO(sharadmv): These shapes fail right now.
          # (64, 8),
      ],
      dtype=[jnp.int32, jnp.float32, jnp.bfloat16],
  )
  def test_basic_async_kernel(self, shape, dtype):
    """Starts an async copy in one kernel invocation and waits in another.

    The DMA semaphore is produced as the *output* of an SC scalar-subcore
    pallas_call, which is the new capability exercised here.
    """
    if not jtu.is_cloud_tpu_at_least(2025, 12, 8):
      self.skipTest("Need newer libtpu")
    x = jnp.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)

    @jax.jit
    def foo(x):
      sc_mesh = plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1)

      # Kernel with an empty body whose only purpose is to return a DMA
      # semaphore allocated in semaphore memory from an SC subcore kernel.
      sem = pl.pallas_call(
          lambda _: None,
          out_shape=pltpu.SemaphoreType.DMA(()),
          out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          compiler_params=pltpu.CompilerParams(
              dimension_semantics=["core_parallel"],
              kernel_type=pltpu.KernelType.SC_SCALAR_SUBCORE,
          ),
      )()

      # Wrap the semaphore and the source/destination buffers in refs so
      # they can be shared across the two core_map kernel invocations below.
      sem_ref = jax.new_ref(sem, memory_space=pltpu.SEMAPHORE)
      y_ref = pl.empty_ref_like(pltpu.HBM(x.shape, x.dtype))
      x_ref = jax.new_ref(x)

      run_kernel = pl.core_map(mesh=sc_mesh)

      # Start the x -> y copy in one kernel invocation and wait on the same
      # semaphore in a separate invocation, so the DMA spans kernel
      # boundaries — this only works if the semaphore persists between them.
      @run_kernel
      def _():
        pltpu.make_async_copy(x_ref, y_ref, sem_ref).start()

      @run_kernel
      def _():
        pltpu.make_async_copy(x_ref, y_ref, sem_ref).wait()

      return y_ref[...]

    o = jax.block_until_ready(foo(x))
    np.testing.assert_array_equal(o, x)
1980+
1981+
1982+
class PallasSparsecoreAsyncTestWithTCTiling(
    TCTilingMixin, PallasSparsecoreAsyncTest
):
  # Re-runs all of PallasSparsecoreAsyncTest with the TC tiling mixin applied;
  # no additional behavior of its own.
  pass
1986+
1987+
19261988
# Standard absl test entry point using JAX's test loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments (0)