diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 32ceb1c21dc2..6c995002976b 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1365,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue): core_map_p = jax_core.Primitive("core_map") core_map_p.multiple_results = True +def _core_map_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +core_map_p.is_high = _core_map_is_high # type: ignore[method-assign] + +def _core_map_to_lojax(*consts, jaxpr, mesh, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + assert not closed_lo_jaxpr.is_high + return core_map_p.bind( + *closed_lo_jaxpr.consts, + jaxpr=closed_lo_jaxpr.jaxpr, + mesh=mesh, + **params, + ) +core_map_p.to_lojax = _core_map_to_lojax + def core_map( mesh, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 9bd15e18f17d..e519a323ca20 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -23,6 +23,7 @@ from typing import Any, Union import jax +from jax import core as jax_core from jax import lax from jax import tree_util from jax._src import util as jax_util @@ -30,11 +31,11 @@ from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import helpers as tpu_helpers -from jax._src.pallas.mosaic import tpu_info from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic import tpu_info +from jax._src.state import types as state_types from jax.experimental import pallas as pl import jax.numpy as jnp -import numpy as np SMEM = tpu_core.MemorySpace.SMEM @@ -79,17 +80,32 @@ def add_leaves(i, x): def _get_tpu_generation() -> int: return tpu_info.get_tpu_info().generation -def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: - # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions - # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and - # (2, 3, 128, 128) -> (1, 1, 8, 128). + +def _make_tiling( + shape: tuple[int, ...], ty: jax_core.AbstractValue +) -> tuple[int | None, ...]: + """Compute a tiling for the given shape and type. + + For a n-dimensional shape, returns (8, 128) for the last 2 dimensions + and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and + (2, 3, 128, 128) -> (1, 1, 8, 128). + + Types are not required to have a dtype, so for such types we return None for + all dimensions because their tiling is unknown. + """ + if len(shape) < 2: raise ValueError(f"Shape must have at least 2 dimensions: {shape=}") + + if not hasattr(ty, 'dtype'): + return (None,) * len(shape) + leading_dims, final_dims = shape[:-2], shape[-2:] # We want to find the minimum power of 2 that fits the second-minor dimension # of shape, with maximum value 8. second_minor, _ = final_dims - packing = 4 // dtype.itemsize + + packing = 4 // ty.dtype.itemsize max_tiling = _TILING[0] second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing while second_minor_tiling < min(second_minor, max_tiling): @@ -114,13 +130,18 @@ def _make_block_ds( assert isinstance(out, pl.Slice) return out -def _create_blocked_slice(block_index: jax.Array | int, - block_size: int, - dim_size: int, - tiling: int): + +def _create_blocked_slice( + block_index: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int | None, +): block_start = block_size * block_index if (dim_rem := dim_size % block_size) == 0: return pl.ds(block_start, block_size) + if tiling is None: + raise ValueError("If tiling is None, block_size must divide dim_size.") if block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") num_blocks = pl.cdiv(dim_size, block_size) @@ -137,12 +158,15 @@ def _create_bounded_slice(slice_start: jax.Array | int, slice_size: jax.Array | int, block_size: int, dim_size: int, - tiling: int): - if block_size % tiling != 0: + tiling: int | None): + if tiling is not None and block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") # We assume by construction that slice_size <= block_size. We also assume # that the slice_start is already aligned to the tiling. + if tiling is None: + return pl.ds(slice_start, slice_size) + # If we are out of bound, we need to round the slice size down to the nearest # multiple of the tiling. is_oob = slice_start + slice_size > dim_size @@ -157,7 +181,7 @@ def _create_bounded_slice(slice_start: jax.Array | int, def _make_block_slice( block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, - tiling: int + tiling: int | None ) -> pl.Slice | slice | int | jax.Array: # Computes a slice given a block index and block size. In the default case, # we return slice(block_index * block_size, (block_index + 1) * block_size). @@ -332,7 +356,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: def compute_index(self): return self.spec.index_map - def get_dma_slice(self, src_shape, src_dtype, grid_indices): + def get_dma_slice(self, src_ty, grid_indices): # We need to handle blocks that might go OOB in the src array. An in bounds # block looks like this (for array shape (600, 600) and block shape # (256, 256)): @@ -379,10 +403,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices): # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block # for the last iteration on that dimension, we will pick the next highest # tile multiple, i.e. (96, 256). + + if (src_shape := getattr(src_ty, "shape", None)) is None: + raise ValueError(f'Type {src_ty} does not have a type.') + if len(src_shape) < 2: raise NotImplementedError("Must use >1D values.") - tiling = _make_tiling(src_shape, src_dtype) + tiling = _make_tiling(src_shape, src_ty) block_indices = self.compute_index(*grid_indices) return tuple( _make_block_slice(bi, bs, ss, t) @@ -403,6 +431,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase: """Returns a new BufferedRefBase with the given block spec.""" raise NotImplementedError() +def _ref_to_value_aval(ref): + """Return the inner of a ref, or a ShapedArray for TransformedRefs.""" + return ( + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype) + if isinstance(ref, state_types.TransformedRef) + else jax.typeof(ref).inner_aval + ) + # TODO(justinfu): Refactor and rename slot fields to reflect cumulative values # instead of slot index. @@ -413,7 +449,6 @@ class BufferedRef(BufferedRefBase): Attributes: spec: pallas blockspec. - dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. window_ref: a multiple-buffer to hold the working and dirty buffers used @@ -444,7 +479,6 @@ class BufferedRef(BufferedRefBase): copy. """ _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True)) - dtype: Any = dataclasses.field(metadata=dict(static=True)) _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True)) window_ref: ArrayRef | None accum_ref: ArrayRef | None @@ -507,7 +541,7 @@ def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, + def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count, needs_swap_ref=True, grid_rank=None, use_lookahead=False, @@ -516,7 +550,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, Args: spec: pallas blockspec. - dtype: dtype for buffers. + dtype_or_type: dtype or aval for buffers. If an aval, the shape is + ignored. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. needs_swap_ref: whether a swap slots tracker needs to be allocated. @@ -527,9 +562,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, Returns: Initialized BufferedRef """ + + # (123, 456) is a dummy shape since we never use ty without + # calling .update(shape=...) first. + ty = ( + dtype_or_type + if isinstance(dtype_or_type, jax_core.AbstractValue) + else jax_core.ShapedArray((123, 456), dtype_or_type) + ) + block_shape = _get_block_shape(spec) if buffer_type is BufferType.ACCUMULATOR: - accum_ref = VMEM(block_shape, dtype) + accum_ref = VMEM.from_type(ty.update(shape=block_shape)) else: accum_ref = None if source_memory_space == VMEM: @@ -541,7 +585,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, f"Cannot hold a non-buffered ref in {spec.memory_space=}") return cls( _spec=spec, - dtype=dtype, _buffer_type=buffer_type, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, @@ -570,11 +613,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, raise ValueError( "grid_rank must be specified when use_lookahead is True." ) + + buffer_ty = ty.update(shape=(buffer_count, *block_shape)) return cls( _spec=spec, - dtype=dtype, _buffer_type=buffer_type, - window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype), + window_ref=buffer_memory_space.from_type(buffer_ty), accum_ref=accum_ref, copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, @@ -601,22 +645,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, ) @classmethod - def input(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs) + def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs + ) @classmethod - def output(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs) + def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs + ) @classmethod - def accumulator(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count, - **kwargs) + def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs + ) @classmethod - def input_output(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count, - **kwargs) + def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs + ) @property def block_shape(self): @@ -923,7 +973,7 @@ def copy_in(self, src_ref, grid_indices): if self.swap is not None: self.swap[0] = True slot = self.current_copy_in_slot - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) dst_slice = tuple( pl.ds(0, s.size) for s, bd in zip(src_slice, self.block_shape) @@ -944,7 +994,7 @@ def copy_out(self, dst_ref, grid_indices): if self.swap is not None: self.swap[0] = True slot = self.current_copy_out_slot - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) src_slice = tuple( pl.ds(0, s.size) for s, bd in zip(dst_slice, self.block_shape) @@ -962,7 +1012,7 @@ def wait_in(self, src_ref, grid_indices): if not self.is_buffered: return assert not (self.window_ref is None or isinstance(self.window_ref, REF)) assert self.sem_recvs is not None - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) dst_slice = tuple( pl.ds(0, s.size) for s, bd in zip(src_slice, self.block_shape) @@ -984,7 +1034,7 @@ def wait_out(self, dst_ref, grid_indices): assert not (self.window_ref is None or isinstance(self.window_ref, REF)) assert self.sem_sends is not None wait_slot = self.current_wait_out_slot - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) src_slice = tuple( pl.ds(0, s.size) for s, bd in zip(dst_slice, self.block_shape) @@ -1682,7 +1732,9 @@ def make_input_bref(in_spec, in_ref): use_lookahead = in_spec.pipeline_mode.use_lookahead if use_lookahead and grid is None: raise ValueError("Grid must be specified when using lookahead.") - return BufferedRef.input(in_spec, in_ref.dtype, buffer_count, + + in_aval = _ref_to_value_aval(in_ref) + return BufferedRef.input(in_spec, in_aval, buffer_count, needs_swap_ref=needs_swap_ref, grid_rank=len(grid), use_lookahead=use_lookahead, @@ -1695,11 +1747,13 @@ def make_output_bref(out_spec, out_ref, accumulate): if out_spec.pipeline_mode.use_lookahead: raise ValueError("Output buffering does not support lookahead.") + out_aval = _ref_to_value_aval(out_ref) + if accumulate: - return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count, + return BufferedRef.accumulator(out_spec, out_aval, buffer_count, needs_swap_ref=needs_swap_ref, source_memory_space=out_ref.memory_space) - return BufferedRef.output(out_spec, out_ref.dtype, buffer_count, + return BufferedRef.output(out_spec, out_aval, buffer_count, needs_swap_ref=needs_swap_ref, source_memory_space=out_ref.memory_space) out_brefs = jax.tree.map( @@ -1817,7 +1871,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): bref = dst hbm_ref = src copy_in = True - hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices) + hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices) bref_slice = tuple( pl.ds(0, s.size) for s, bd in zip(hbm_slice, bref.block_shape) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index e472253fb231..47a107368f96 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -383,6 +383,52 @@ def _get_dma_effects( dma_start_p = jax_core.Primitive('dma_start') dma_start_p.multiple_results = True +def _dma_is_high(*avals, **params): + return any(aval.is_high for aval in avals) + +dma_start_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_start_to_lojax(*args, tree, device_id_type, priority, add): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + + src_ref_aval.inner_aval.dma_start( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + priority=priority, + device_id_type=device_id_type, + add=add + ) + return [] +dma_start_p.to_lojax = _dma_start_to_lojax + @dma_start_p.def_effectful_abstract_eval def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add): if priority < 0: @@ -646,6 +692,46 @@ def do_discharge_src_sem(src_sem=src_sem): dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True +dma_wait_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_wait_to_lojax(*args, tree, device_id_type): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + src_ref_aval.inner_aval.dma_wait( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + return [] +dma_wait_p.to_lojax = _dma_wait_to_lojax + @dma_wait_p.def_effectful_abstract_eval def _dma_wait_abstract_eval(*args, tree, device_id_type): del device_id_type diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index b612c139acda..4ae79d0769e6 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -854,6 +854,17 @@ def wrap_with_transforms(f, transforms, *args): run_scoped_p = jax_core.Primitive("run_scoped") run_scoped_p.multiple_results = True +def _run_scoped_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +run_scoped_p.is_high = _run_scoped_is_high # type: ignore[method-assign] + +def _run_scoped_to_lojax(*args, jaxpr, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, args) + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + consts = closed_lo_jaxpr.consts + return run_scoped_p.bind(*consts, jaxpr=closed_lo_jaxpr.jaxpr, **params) +run_scoped_p.to_lojax = _run_scoped_to_lojax def run_scoped( f: Callable[..., Any], diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 2644f8392416..e9f589a10f27 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -409,7 +409,10 @@ def is_high(self): return self.inner_aval.is_high def lo_ty(self): - return map(AbstractRef, self.inner_aval.lo_ty()) + return [ + AbstractRef(x, memory_space=self.memory_space) + for x in self.inner_aval.lo_ty() + ] def lower_val(self, ref): if not self.is_high: diff --git a/jax/experimental/hijax.py b/jax/experimental/hijax.py index 5e5bb0512c79..087569ae9234 100644 --- a/jax/experimental/hijax.py +++ b/jax/experimental/hijax.py @@ -36,4 +36,5 @@ ) from jax._src.state import ( AbstractRef as AbstractRef, + TransformedRef as TransformedRef ) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 3bfde20840a3..26780d42252c 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -475,6 +475,15 @@ MemRefType getMemRefType(Value value) { return cast(value.getType()); } +template +bool checkBothOperandsDivisible(Value value, int64_t divisor, int64_t fuel) { + if (auto op = value.getDefiningOp()) { + return isGuaranteedDivisible(op.getLhs(), divisor, fuel / 2) && + isGuaranteedDivisible(op.getRhs(), divisor, (fuel + 1) / 2); + } + return false; +} + bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; @@ -497,9 +506,16 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (auto cast_op = value.getDefiningOp()) { return isGuaranteedDivisible(cast_op.getOperand(), divisor, fuel - 1); } - if (auto add_op = value.getDefiningOp()) { - return isGuaranteedDivisible(add_op.getRhs(), divisor, fuel / 2) && - isGuaranteedDivisible(add_op.getLhs(), divisor, (fuel + 1) / 2); + if (checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel)) { + return true; + } + if (auto select_op = value.getDefiningOp()) { + return isGuaranteedDivisible(select_op.getTrueValue(), divisor, fuel / 2) && + isGuaranteedDivisible(select_op.getFalseValue(), divisor, + (fuel + 1) / 2); } return false; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 16f710836f0b..9d9fcf624a40 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -77,7 +77,7 @@ LogicalResult specializeMemorySpace(TypedValue value, // vector ops. This functions inverts the layout erasure applied to the value. MemRefType getMemRefType(Value value); -bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8); +bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 128); DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, bool transpose_lhs, diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 4128e6dffda2..0d03cc9b5444 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -538,6 +538,7 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ + "//jax/experimental:hijax", "//jax/experimental:mesh_utils", "//jax/experimental:pallas_tpu", "//jax/experimental:pallas_tpu_ops", diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 42f7674f7adc..257b1a474cb6 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -14,6 +14,7 @@ """Test TPU-specific extensions to pallas_call.""" +import dataclasses import functools from absl.testing import absltest from absl.testing import parameterized @@ -21,10 +22,14 @@ import hypothesis.strategies as hps import jax from jax import lax +from jax._src import hijax +from jax._src import shard_map +from jax._src import state from jax._src import test_util as jtu +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -2189,5 +2194,253 @@ def f(x, slices): np.testing.assert_allclose(out, x) +class PipelineHijaxTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+.') + + def test_emit_pipeline_hijax(self): + @dataclasses.dataclass(frozen=True) + class ArrayTuple: + x0: jax.Array + x1: jax.Array + + @property + def shape(self): + assert self.x0.shape == self.x1.shape + return self.x0.shape + + @property + def dtype(self): + assert self.x0.dtype == self.x1.dtype + return self.x0.dtype + + @dataclasses.dataclass(frozen=True) + class ShapedArrayTuple(hijax.HiType): + shape: tuple[int, ...] + dtype: jnp.dtype + + update = dataclasses.replace + + def lo_ty(self) -> list[hijax.ShapedArray]: + return [hijax.ShapedArray(self.shape, self.dtype)] * 2 + + def lower_val(self, hi_val: ArrayTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + + def raise_val(self, x0, x1) -> ArrayTuple: + return ArrayTuple(x0, x1) + + def ref_get_abstract_eval(self, ref_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out, effects = state_primitives.get_p.abstract_eval( + updated_ref, *args, tree=tree + ) + assert isinstance(out, hijax.ShapedArray) + return ShapedArrayTuple(out.shape, out.dtype), effects + + def ref_get_to_lojax( + self, ref: state.TransformedRef | jax.Ref, idx: indexing.NDIndexer + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.get_p.bind( + tup_ref.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.get_p.bind( + tup_ref.x1, *flat_transforms, tree=tree + ) + return ShapedArrayTuple(x0_out, x1_out).raise_val(x0_out, x1_out) + + def ref_swap_abstract_eval(self, ref_aval, val_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + val_arr_aval = hijax.ShapedArray(val_aval.shape, val_aval.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out_aval, effects = state_primitives.swap_p.abstract_eval( + updated_ref, val_arr_aval, *args, tree=tree + ) + assert isinstance(out_aval, hijax.ShapedArray) + return ShapedArrayTuple(out_aval.shape, out_aval.dtype), effects + + def ref_swap_to_lojax( + self, + ref: state.TransformedRef | jax.Ref, + val: ArrayTuple, + idx: indexing.NDIndexer, + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.swap_p.bind( + tup_ref.x0, val.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.swap_p.bind( + tup_ref.x1, val.x1, *flat_transforms, tree=tree + ) + return self.raise_val(x0_out, x1_out) + + def lower_block_spec( + self, block_spec: pl.BlockSpec + ) -> list[pl.BlockSpec]: + return [block_spec, block_spec] + + def dma_start( + self, + src_ref: state.TransformedRef, + dst_ref: state.TransformedRef, + src_sem: state.TransformedRef, + dst_sem: state.TransformedRef, + device_id: jax.Array | int | None, + device_id_type: pl.DeviceIdType, + priority: int, + add: bool, + ) -> None: + del add + src_aval = jax.typeof(src_ref.ref).inner_aval + assert isinstance(src_aval, ShapedArrayTuple) + dst_aval = jax.typeof(dst_ref.ref).inner_aval + assert isinstance(dst_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.start(priority=priority) + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + + def dma_wait( + self, src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + assert isinstance(jax.typeof(src_ref.ref).inner_aval, ShapedArrayTuple) + assert isinstance(jax.typeof(dst_ref.ref).inner_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.wait() + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + + hijax.register_hitype( + ArrayTuple, lambda q: ShapedArrayTuple(q.shape, q.dtype) + ) + + def kernel(x_hbm_ref, o_hbm_ref): + def body(x_ref, o_ref): + o_ref[...] = x_ref[...] + + num_steps = 4 + block_shape = (x_hbm_ref.shape[0] // num_steps, x_hbm_ref.shape[1]) + + pltpu.emit_pipeline( + body, + grid=(num_steps,), + in_specs=(pl.BlockSpec(block_shape, lambda i: (i, 0)),), + out_specs=pl.BlockSpec(block_shape, lambda i: (i, 0)), + )(x_hbm_ref, o_hbm_ref) + + inp = ArrayTuple( + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + ) + + out_ty = ShapedArrayTuple( + inp.shape, + inp.dtype, + ) + + out = pl.pallas_call( + kernel, + in_specs=(pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),), + out_shape=out_ty, + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + )(inp) + + np.testing.assert_allclose(out.x0, inp.x0) + np.testing.assert_allclose(out.x1, inp.x1) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())