Skip to content

Commit b4ed2f9

Browse files
Davis YoshidaGoogle-ML-Automation
authored andcommitted
Support Hijax types in emit_pipeline.
PiperOrigin-RevId: 841158497
1 parent 6a1397e commit b4ed2f9

File tree

10 files changed

+490
-49
lines changed

10 files changed

+490
-49
lines changed

jax/_src/pallas/core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue):
13651365
core_map_p = jax_core.Primitive("core_map")
13661366
core_map_p.multiple_results = True
13671367

1368+
def _core_map_is_high(*avals, jaxpr, **params):
1369+
del avals, params
1370+
return jaxpr.is_high
1371+
core_map_p.is_high = _core_map_is_high # type: ignore[method-assign]
1372+
1373+
def _core_map_to_lojax(*consts, jaxpr, mesh, **params):
1374+
closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
1375+
with (
1376+
tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()),
1377+
jax_core.extend_axis_env_nd(mesh.shape.items()),
1378+
):
1379+
closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
1380+
assert not closed_lo_jaxpr.is_high
1381+
return core_map_p.bind(
1382+
*closed_lo_jaxpr.consts,
1383+
jaxpr=closed_lo_jaxpr.jaxpr,
1384+
mesh=mesh,
1385+
**params,
1386+
)
1387+
core_map_p.to_lojax = _core_map_to_lojax
1388+
13681389

13691390
def core_map(
13701391
mesh,

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,19 @@
2323
from typing import Any, Union
2424

2525
import jax
26+
from jax import core as jax_core
2627
from jax import lax
2728
from jax import tree_util
2829
from jax._src import util as jax_util
2930
from jax._src.pallas import core as pallas_core
3031
from jax._src.pallas import primitives as primitives
3132
from jax._src.pallas.mosaic import core as tpu_core
3233
from jax._src.pallas.mosaic import helpers as tpu_helpers
33-
from jax._src.pallas.mosaic import tpu_info
3434
from jax._src.pallas.mosaic import primitives as tpu_primitives
35+
from jax._src.pallas.mosaic import tpu_info
36+
from jax._src.state import types as state_types
3537
from jax.experimental import pallas as pl
3638
import jax.numpy as jnp
37-
import numpy as np
3839

3940

4041
SMEM = tpu_core.MemorySpace.SMEM
@@ -79,17 +80,32 @@ def add_leaves(i, x):
7980
def _get_tpu_generation() -> int:
8081
return tpu_info.get_tpu_info().generation
8182

82-
def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
83-
# For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
84-
# and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
85-
# (2, 3, 128, 128) -> (1, 1, 8, 128).
83+
84+
def _make_tiling(
85+
shape: tuple[int, ...], ty: jax_core.AbstractValue
86+
) -> tuple[int | None, ...]:
87+
"""Compute a tiling for the given shape and type.
88+
89+
For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
90+
and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
91+
(2, 3, 128, 128) -> (1, 1, 8, 128).
92+
93+
Types are not required to have a dtype, so for such types we return None for
94+
all dimensions because their tiling is unknown.
95+
"""
96+
8697
if len(shape) < 2:
8798
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
99+
100+
if not hasattr(ty, 'dtype'):
101+
return (None,) * len(shape)
102+
88103
leading_dims, final_dims = shape[:-2], shape[-2:]
89104
# We want to find the minimum power of 2 that fits the second-minor dimension
90105
# of shape, with maximum value 8.
91106
second_minor, _ = final_dims
92-
packing = 4 // dtype.itemsize
107+
108+
packing = 4 // ty.dtype.itemsize
93109
max_tiling = _TILING[0]
94110
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
95111
while second_minor_tiling < min(second_minor, max_tiling):
@@ -114,17 +130,23 @@ def _make_block_ds(
114130
assert isinstance(out, pl.Slice)
115131
return out
116132

117-
def _create_blocked_slice(block_index: jax.Array | int,
118-
block_size: int,
119-
dim_size: int,
120-
tiling: int):
133+
134+
def _create_blocked_slice(
135+
block_index: jax.Array | int,
136+
block_size: int,
137+
dim_size: int,
138+
tiling: int | None,
139+
):
121140
block_start = block_size * block_index
122141
if (dim_rem := dim_size % block_size) == 0:
123142
return pl.ds(block_start, block_size)
143+
if tiling is None:
144+
raise ValueError("If tiling is None, block_size must divide dim_size.")
124145
if block_size % tiling != 0:
125146
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
126147
num_blocks = pl.cdiv(dim_size, block_size)
127148
is_last = block_index == num_blocks - 1
149+
128150
rounded_size = jnp.where(
129151
is_last,
130152
_round_up_to_nearest_multiple(dim_rem % block_size, tiling),
@@ -133,20 +155,26 @@ def _create_blocked_slice(block_index: jax.Array | int,
133155
rounded_size = pl.multiple_of(rounded_size, tiling)
134156
return pl.ds(block_index * block_size, rounded_size)
135157

158+
136159
def _create_bounded_slice(slice_start: jax.Array | int,
137160
slice_size: jax.Array | int,
138161
block_size: int,
139162
dim_size: int,
140-
tiling: int):
141-
if block_size % tiling != 0:
163+
tiling: int | None):
164+
if tiling is not None and block_size % tiling != 0:
142165
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
166+
143167
# We assume by construction that slice_size <= block_size. We also assume
144168
# that the slice_start is already aligned to the tiling.
145169

170+
if tiling is None:
171+
return pl.ds(slice_start, slice_size)
172+
146173
# If we are out of bound, we need to round the slice size down to the nearest
147174
# multiple of the tiling.
148175
is_oob = slice_start + slice_size > dim_size
149176
remaining = dim_size - slice_start
177+
150178
rounded_size = jnp.where(
151179
is_oob,
152180
_round_up_to_nearest_multiple(remaining, tiling),
@@ -157,7 +185,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,
157185

158186
def _make_block_slice(
159187
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
160-
tiling: int
188+
tiling: int | None
161189
) -> pl.Slice | slice | int | jax.Array:
162190
# Computes a slice given a block index and block size. In the default case,
163191
# we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +360,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
332360
def compute_index(self):
333361
return self.spec.index_map
334362

335-
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
363+
def get_dma_slice(self, src_ty, grid_indices):
336364
# We need to handle blocks that might go OOB in the src array. An in bounds
337365
# block looks like this (for array shape (600, 600) and block shape
338366
# (256, 256)):
@@ -379,10 +407,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
379407
# Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
380408
# for the last iteration on that dimension, we will pick the next highest
381409
# tile multiple, i.e. (96, 256).
410+
411+
if (src_shape := getattr(src_ty, "shape", None)) is None:
412+
raise ValueError(f'Type {src_ty} does not have a type.')
413+
382414
if len(src_shape) < 2:
383415
raise NotImplementedError("Must use >1D values.")
384416

385-
tiling = _make_tiling(src_shape, src_dtype)
417+
tiling = _make_tiling(src_shape, src_ty)
386418
block_indices = self.compute_index(*grid_indices)
387419
return tuple(
388420
_make_block_slice(bi, bs, ss, t)
@@ -403,6 +435,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
403435
"""Returns a new BufferedRefBase with the given block spec."""
404436
raise NotImplementedError()
405437

438+
def _ref_to_value_aval(ref):
439+
"""Return the inner of a ref, or a ShapedArray for TransformedRefs."""
440+
return (
441+
jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
442+
if isinstance(ref, state_types.TransformedRef)
443+
else jax.typeof(ref).inner_aval
444+
)
445+
406446

407447
# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
408448
# instead of slot index.
@@ -413,7 +453,6 @@ class BufferedRef(BufferedRefBase):
413453
414454
Attributes:
415455
spec: pallas blockspec.
416-
dtype: dtype for buffers.
417456
buffer_type: enum indicating whether this is an input, output, or in/out
418457
accumulator buffered reference.
419458
window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +483,6 @@ class BufferedRef(BufferedRefBase):
444483
copy.
445484
"""
446485
_spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True))
447-
dtype: Any = dataclasses.field(metadata=dict(static=True))
448486
_buffer_type: BufferType = dataclasses.field(metadata=dict(static=True))
449487
window_ref: ArrayRef | None
450488
accum_ref: ArrayRef | None
@@ -507,7 +545,7 @@ def buffer_types() -> type[BufferType]:
507545
return BufferType
508546

509547
@classmethod
510-
def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
548+
def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count,
511549
needs_swap_ref=True,
512550
grid_rank=None,
513551
use_lookahead=False,
@@ -516,7 +554,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
516554
517555
Args:
518556
spec: pallas blockspec.
519-
dtype: dtype for buffers.
557+
dtype_or_type: dtype or aval for buffers. If an aval, the shape is
558+
ignored.
520559
buffer_type: enum indicating whether this is an input, output, or in/out
521560
accumulator buffered reference.
522561
needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +566,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
527566
Returns:
528567
Initialized BufferedRef
529568
"""
569+
570+
# (123, 456) is a dummy shape since we never use ty without
571+
# calling .update(shape=...) first.
572+
ty = (
573+
dtype_or_type
574+
if isinstance(dtype_or_type, jax_core.AbstractValue)
575+
else jax_core.ShapedArray((123, 456), dtype_or_type)
576+
)
577+
530578
block_shape = _get_block_shape(spec)
531579
if buffer_type is BufferType.ACCUMULATOR:
532-
accum_ref = VMEM(block_shape, dtype)
580+
accum_ref = VMEM.from_type(ty.update(shape=block_shape))
533581
else:
534582
accum_ref = None
535583
if source_memory_space == VMEM:
@@ -541,7 +589,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
541589
f"Cannot hold a non-buffered ref in {spec.memory_space=}")
542590
return cls(
543591
_spec=spec,
544-
dtype=dtype,
545592
_buffer_type=buffer_type,
546593
window_ref=None, # to be bound to existing ref by the pipeline routine
547594
accum_ref=accum_ref,
@@ -570,11 +617,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
570617
raise ValueError(
571618
"grid_rank must be specified when use_lookahead is True."
572619
)
620+
621+
buffer_ty = ty.update(shape=(buffer_count, *block_shape))
573622
return cls(
574623
_spec=spec,
575-
dtype=dtype,
576624
_buffer_type=buffer_type,
577-
window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype),
625+
window_ref=buffer_memory_space.from_type(buffer_ty),
578626
accum_ref=accum_ref,
579627
copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
580628
wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
@@ -601,22 +649,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
601649
)
602650

603651
@classmethod
604-
def input(cls, spec, dtype, buffer_count=2, **kwargs):
605-
return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs)
652+
def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
653+
return cls.create(
654+
spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs
655+
)
606656

607657
@classmethod
608-
def output(cls, spec, dtype, buffer_count=2, **kwargs):
609-
return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs)
658+
def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
659+
return cls.create(
660+
spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs
661+
)
610662

611663
@classmethod
612-
def accumulator(cls, spec, dtype, buffer_count=2, **kwargs):
613-
return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count,
614-
**kwargs)
664+
def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
665+
return cls.create(
666+
spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs
667+
)
615668

616669
@classmethod
617-
def input_output(cls, spec, dtype, buffer_count=2, **kwargs):
618-
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count,
619-
**kwargs)
670+
def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
671+
return cls.create(
672+
spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs
673+
)
620674

621675
@property
622676
def block_shape(self):
@@ -923,7 +977,7 @@ def copy_in(self, src_ref, grid_indices):
923977
if self.swap is not None:
924978
self.swap[0] = True
925979
slot = self.current_copy_in_slot
926-
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
980+
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
927981
dst_slice = tuple(
928982
pl.ds(0, s.size)
929983
for s, bd in zip(src_slice, self.block_shape)
@@ -944,7 +998,7 @@ def copy_out(self, dst_ref, grid_indices):
944998
if self.swap is not None:
945999
self.swap[0] = True
9461000
slot = self.current_copy_out_slot
947-
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
1001+
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
9481002
src_slice = tuple(
9491003
pl.ds(0, s.size)
9501004
for s, bd in zip(dst_slice, self.block_shape)
@@ -962,7 +1016,7 @@ def wait_in(self, src_ref, grid_indices):
9621016
if not self.is_buffered: return
9631017
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
9641018
assert self.sem_recvs is not None
965-
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
1019+
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
9661020
dst_slice = tuple(
9671021
pl.ds(0, s.size)
9681022
for s, bd in zip(src_slice, self.block_shape)
@@ -984,7 +1038,7 @@ def wait_out(self, dst_ref, grid_indices):
9841038
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
9851039
assert self.sem_sends is not None
9861040
wait_slot = self.current_wait_out_slot
987-
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
1041+
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
9881042
src_slice = tuple(
9891043
pl.ds(0, s.size)
9901044
for s, bd in zip(dst_slice, self.block_shape)
@@ -1682,7 +1736,9 @@ def make_input_bref(in_spec, in_ref):
16821736
use_lookahead = in_spec.pipeline_mode.use_lookahead
16831737
if use_lookahead and grid is None:
16841738
raise ValueError("Grid must be specified when using lookahead.")
1685-
return BufferedRef.input(in_spec, in_ref.dtype, buffer_count,
1739+
1740+
in_aval = _ref_to_value_aval(in_ref)
1741+
return BufferedRef.input(in_spec, in_aval, buffer_count,
16861742
needs_swap_ref=needs_swap_ref,
16871743
grid_rank=len(grid),
16881744
use_lookahead=use_lookahead,
@@ -1695,11 +1751,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
16951751
if out_spec.pipeline_mode.use_lookahead:
16961752
raise ValueError("Output buffering does not support lookahead.")
16971753

1754+
out_aval = _ref_to_value_aval(out_ref)
1755+
16981756
if accumulate:
1699-
return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count,
1757+
return BufferedRef.accumulator(out_spec, out_aval, buffer_count,
17001758
needs_swap_ref=needs_swap_ref,
17011759
source_memory_space=out_ref.memory_space)
1702-
return BufferedRef.output(out_spec, out_ref.dtype, buffer_count,
1760+
return BufferedRef.output(out_spec, out_aval, buffer_count,
17031761
needs_swap_ref=needs_swap_ref,
17041762
source_memory_space=out_ref.memory_space)
17051763
out_brefs = jax.tree.map(
@@ -1817,7 +1875,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
18171875
bref = dst
18181876
hbm_ref = src
18191877
copy_in = True
1820-
hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices)
1878+
hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices)
18211879
bref_slice = tuple(
18221880
pl.ds(0, s.size)
18231881
for s, bd in zip(hbm_slice, bref.block_shape)

0 commit comments

Comments
 (0)