Commit 548eaa5

sharadmv authored and Google-ML-Automation committed
[Pallas TPU] Allow closed over scalars in core_map code
This allows, for example, dynamically indexing Refs using ordinary scalars captured from outside the kernel.

PiperOrigin-RevId: 842429684
1 parent f374387 commit 548eaa5
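A minimal sketch of what this change enables, adapted from the test added in this commit (tests/pallas/tpu_pallas_state_test.py); the import paths are the conventional public Pallas ones and are assumed here, since the diff does not show the test file's header:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl  # assumed import path
from jax.experimental.pallas import tpu as pltpu  # assumed import path

@jax.jit
def dynamic_row(x, i):
  # `i` is an ordinary traced scalar captured by the kernel's closure.
  # Before this change, core_map rejected any captured constant; now scalar
  # captures are allowed (non-scalar arrays must still be passed as inputs).
  @pl.kernel(out_shape=jax.ShapeDtypeStruct(x.shape[1:], jnp.int32),
             mesh=pltpu.create_tensorcore_mesh("x", num_cores=1))
  def kernel(x_ref, out_ref):
    pltpu.sync_copy(x_ref.at[i], out_ref)  # dynamic index via the scalar
  return kernel(x)

x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128))
np.testing.assert_array_equal(dynamic_row(x, 2), x[2])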

5 files changed: +105 −7 lines

jax/_src/pallas/core.py

Lines changed: 20 additions & 6 deletions
@@ -1535,8 +1535,13 @@ def default_mesh_discharge_rule(
     scratch_shapes,
 ):
   """Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
-  del out_avals  # Unused.
   default_memory_space = memory_space
+  if not all(
+      isinstance(aval, state.AbstractRef) for aval in (in_avals + out_avals)
+  ):
+    raise ValueError(
+        "default_mesh_discharge_rule only supports Ref inputs/outputs."
+    )
 
   def body(*args):
     # Due to aliasing, ``args`` contains aliased inputs and outputs so we
@@ -1605,15 +1610,24 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, debug_info,
       for var in jaxpr.constvars
       if not isinstance(aval := var.aval, state.AbstractRef)
   ]
-  if consts_avals:
+  is_scalar_const_aval = [
+      isinstance(aval, jax_core.ShapedArray) and not aval.shape
+      for aval in consts_avals
+  ]
+  if not all(is_scalar_const_aval):
     ctx = jax_core.JaxprPpContext()
-    pp_const_avals = ", ".join(
-        jax_core.pp_aval(aval, ctx) for aval in consts_avals
+    non_scalar_const_avals = [
+        aval
+        for aval, is_scalar in zip(consts_avals, is_scalar_const_aval)
+        if not is_scalar
+    ]
+    non_scalar_const_pp_avals = ", ".join(
+        jax_core.pp_aval(aval, ctx) for aval in non_scalar_const_avals
     )
     raise ValueError(
         "The kernel function in core_map"
-        f" {debug_info.func_src_info} captures constants"
-        f" [{pp_const_avals}]. You should pass them as inputs."
+        f" {debug_info.func_src_info} captures non-scalar constants"
+        f" [{non_scalar_const_pp_avals}]. You should pass them as inputs."
     )
   return _core_map_mesh_rules[type(mesh)](
       in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs
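The relaxed check in _core_map_discharge_rule keys on whether a captured constant's abstract value is a scalar: a ShapedArray with an empty shape tuple. A rough standalone illustration of that predicate, phrased against public JAX shape structs rather than the in-tree jax_core module:

import jax
import jax.numpy as jnp

def is_scalar_aval(aval):
  # Mirrors `isinstance(aval, jax_core.ShapedArray) and not aval.shape`
  # from the diff above.
  return hasattr(aval, "shape") and aval.shape == ()

scalar = jax.eval_shape(lambda: jnp.int32(0))                  # shape ()
array = jax.eval_shape(lambda: jnp.zeros((8, 128), jnp.int32)) # shape (8, 128)
assert is_scalar_aval(scalar)      # may now be closed over by the kernel
assert not is_scalar_aval(array)   # still must be passed as an explicit input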

jax/_src/pallas/mosaic/core.py

Lines changed: 45 additions & 0 deletions
@@ -26,9 +26,11 @@
 import jax.numpy as jnp
 from jax.extend import backend as jex_backend
 from jax._src import core as jax_core
+from jax._src import linear_util as lu
 from jax._src import state
 from jax._src import util
 from jax._src.frozen_dict import FrozenDict
+from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 import numpy as np
 
@@ -336,6 +338,49 @@ def _tensorcore_mesh_discharge_rule(
         "TensorCoreMesh does not support VMEM inputs/outputs when there are"
         " >1 cores. Use HBM or ANY instead."
     )
+  def allowed_aval(aval):
+    if isinstance(aval, state.AbstractRef):
+      return True
+    if isinstance(aval, jax_core.ShapedArray):
+      # Only scalars are allowed.
+      return not aval.shape
+    return False
+  assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars)
+
+  is_scalar_const = [
+      isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape
+      for v in jaxpr.constvars
+  ]
+  if any(is_scalar_const):
+    # Rewrite body jaxpr to take in scalar values as Refs.
+    def new_body(*args):
+      args = [
+          a[0] if is_scalar else a
+          for a, is_scalar in zip(args, is_scalar_const)
+      ]
+      return jax_core.eval_jaxpr(jaxpr, args)
+    # TODO(sharadmv): Remove this once Mosaic supports passing scalars as values.
+    new_trace_avals = [
+        state.AbstractRef(  # pylint: disable=g-long-ternary
+            jax_core.ShapedArray((1,), v.aval.dtype),
+            memory_space=MemorySpace.SMEM,
+        )
+        if is_scalar
+        else v.aval
+        for v, is_scalar in zip(jaxpr.constvars, is_scalar_const)
+    ]
+    new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+        lu.wrap_init(
+            new_body, debug_info=jaxpr.debug_info.with_unknown_names()
+        ),
+        new_trace_avals,
+    )
+    jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars)
+    args = tuple(
+        a[None] if is_scalar else a
+        for a, is_scalar in zip(args, is_scalar_const)
+    )
+    in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)])
   return pallas_core.default_mesh_discharge_rule(
       in_avals,
       out_avals,
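The TensorCoreMesh rule above works around Mosaic not yet supporting scalars passed by value: each captured scalar is promoted to a shape-(1,) Ref in SMEM, the caller reshapes the scalar with a[None], and the rewritten body reads element 0 to recover it. The reshape/readback pair, illustrated outside of Pallas:

import jax.numpy as jnp

i = jnp.int32(3)              # a captured scalar; its aval has shape ()
as_smem_value = i[None]       # what the rule passes in: a shape-(1,) value
assert as_smem_value.shape == (1,)
recovered = as_smem_value[0]  # what `new_body` does: `a[0] if is_scalar`
assert recovered == i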

jax/_src/pallas/mosaic/sc_core.py

Lines changed: 5 additions & 0 deletions
@@ -219,6 +219,11 @@ def _scalar_subcore_mesh_discharge_rule(
     compiler_params = tpu_core.CompilerParams()
   if compiler_params.dimension_semantics is not None:
     raise ValueError("ScalarSubcoreMesh does not support dimension_semantics=")
+  sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)]
+  if sa_avals:
+    raise NotImplementedError(
+        f"Cannot close over values in core_map: {sa_avals}"
+    )
   return pallas_core.default_mesh_discharge_rule(
       in_avals,
       out_avals,

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 5 additions & 0 deletions
@@ -1368,6 +1368,11 @@ def _gpu_mesh_discharge_rule(
     )
   if not compiler_params:
     compiler_params = CompilerParams()
+  sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)]
+  if sa_avals:
+    raise NotImplementedError(
+        f"Cannot close over values in core_map: {sa_avals}"
+    )
   return pallas_core.default_mesh_discharge_rule(
       in_avals,
       out_avals,

tests/pallas/tpu_pallas_state_test.py

Lines changed: 30 additions & 1 deletion
@@ -324,9 +324,38 @@ def kernel(x_ref, out_ref, tmp_ref):
       return kernel(x)
 
     x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
-    with self.assertRaisesRegex(Exception, "core_map .* captures constants"):
+    with self.assertRaisesRegex(
+        Exception, "core_map .* captures non-scalar constants"
+    ):
       f(x)
 
+  def test_capture_scalar(self):
+    @jax.jit
+    def f(x, i):
+      @pl.kernel(out_shape=jax.ShapeDtypeStruct(x.shape[1:], jnp.int32),
+                 mesh=pltpu.create_tensorcore_mesh("x", num_cores=1))
+      def kernel(x_ref, out_ref):
+        pltpu.sync_copy(x_ref.at[i], out_ref)
+      return kernel(x)
+
+    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128))
+    for i in range(x.shape[0]):
+      out = f(x, i)
+      np.testing.assert_array_equal(out, x[i])
+
+    @jax.jit
+    def g(x, i):
+      @pl.kernel(out_shape=jax.ShapeDtypeStruct((2, *x.shape[1:]), jnp.int32),
+                 mesh=pltpu.create_tensorcore_mesh("x", num_cores=1))
+      def kernel(x_ref, out_ref):
+        pltpu.sync_copy(x_ref.at[pl.ds(i, 2)], out_ref)
+      return kernel(x)
+
+    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128))
+    for i in range(3):
+      out = g(x, i)
+      np.testing.assert_array_equal(out, x[i:i+2])
+
   def test_kernel_helper_with_scratch(self):
     mesh = pltpu.create_tensorcore_mesh("x")
     def body(x_ref, o_ref, scratch_ref):
