Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue):
core_map_p = jax_core.Primitive("core_map")
core_map_p.multiple_results = True

def _core_map_is_high(*avals, jaxpr, **params):
  """Report whether a `core_map` equation has a high-level body.

  Args:
    *avals: Abstract values of the operands. Ignored.
    jaxpr: The body jaxpr of the equation; only its `is_high` attribute is
      consulted.
    **params: Remaining equation parameters. Ignored.

  Returns:
    The value of `jaxpr.is_high`.
  """
  # The operands and the other parameters play no part in the answer.
  del avals, params
  return jaxpr.is_high
core_map_p.is_high = _core_map_is_high # type: ignore[method-assign]

def _core_map_to_lojax(*consts, jaxpr, mesh, **params):
# Lowering hook for `core_map`: takes the equation's operands (`consts`) and
# body `jaxpr`, lowers the body with `pe.lower_jaxpr`, and re-binds `core_map_p`
# on the lowered result.
# NOTE(review): indentation is not visible in this view, so the exact extent of
# the `with` body below cannot be confirmed from here.
#
# Pair the body jaxpr with the operands so it can be lowered as a closed jaxpr.
closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
with (
# Grid environment built from the values of `mesh.shape`, with no mapped dims.
tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()),
# Axis environment extended with the items of `mesh.shape`.
jax_core.extend_axis_env_nd(mesh.shape.items()),
):
closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
# Sanity check: the lowered jaxpr must no longer report itself as high-level.
assert not closed_lo_jaxpr.is_high
# Re-bind with the lowered jaxpr and its consts; `mesh` and the remaining
# params are forwarded unchanged.
return core_map_p.bind(
*closed_lo_jaxpr.consts,
jaxpr=closed_lo_jaxpr.jaxpr,
mesh=mesh,
**params,
)
# Install the function above as the primitive's `to_lojax` hook.
core_map_p.to_lojax = _core_map_to_lojax


def core_map(
mesh,
Expand Down
138 changes: 96 additions & 42 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@
from typing import Any, Union

import jax
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import util as jax_util
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives as primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import helpers as tpu_helpers
from jax._src.pallas.mosaic import tpu_info
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.pallas.mosaic import tpu_info
from jax._src.state import types as state_types
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np


SMEM = tpu_core.MemorySpace.SMEM
Expand Down Expand Up @@ -79,17 +80,32 @@ def add_leaves(i, x):
def _get_tpu_generation() -> int:
  """Return the `generation` field of `tpu_info.get_tpu_info()`."""
  return tpu_info.get_tpu_info().generation

def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
# For an n-dimensional shape, returns (8, 128) for the last 2 dimensions
# and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
# (2, 3, 128, 128) -> (1, 1, 8, 128).

def _make_tiling(
shape: tuple[int, ...], ty: jax_core.AbstractValue
) -> tuple[int | None, ...]:
"""Compute a tiling for the given shape and type.
For an n-dimensional shape, returns (8, 128) for the last 2 dimensions
and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
(2, 3, 128, 128) -> (1, 1, 8, 128).
Types are not required to have a dtype, so for such types we return None for
all dimensions because their tiling is unknown.
"""

if len(shape) < 2:
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")

if not hasattr(ty, 'dtype'):
return (None,) * len(shape)

leading_dims, final_dims = shape[:-2], shape[-2:]
# We want to find the minimum power of 2 that fits the second-minor dimension
# of shape, with maximum value 8.
second_minor, _ = final_dims
packing = 4 // dtype.itemsize

packing = 4 // ty.dtype.itemsize
max_tiling = _TILING[0]
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
while second_minor_tiling < min(second_minor, max_tiling):
Expand All @@ -114,13 +130,18 @@ def _make_block_ds(
assert isinstance(out, pl.Slice)
return out

def _create_blocked_slice(block_index: jax.Array | int,
block_size: int,
dim_size: int,
tiling: int):

def _create_blocked_slice(
block_index: jax.Array | int,
block_size: int,
dim_size: int,
tiling: int | None,
):
block_start = block_size * block_index
if (dim_rem := dim_size % block_size) == 0:
return pl.ds(block_start, block_size)
if tiling is None:
raise ValueError("If tiling is None, block_size must divide dim_size.")
if block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
num_blocks = pl.cdiv(dim_size, block_size)
Expand All @@ -137,12 +158,15 @@ def _create_bounded_slice(slice_start: jax.Array | int,
slice_size: jax.Array | int,
block_size: int,
dim_size: int,
tiling: int):
if block_size % tiling != 0:
tiling: int | None):
if tiling is not None and block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
# We assume by construction that slice_size <= block_size. We also assume
# that the slice_start is already aligned to the tiling.

if tiling is None:
return pl.ds(slice_start, slice_size)

# If we are out of bound, we need to round the slice size down to the nearest
# multiple of the tiling.
is_oob = slice_start + slice_size > dim_size
Expand All @@ -157,7 +181,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,

def _make_block_slice(
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
tiling: int
tiling: int | None
) -> pl.Slice | slice | int | jax.Array:
# Computes a slice given a block index and block size. In the default case,
# we return slice(block_index * block_size, (block_index + 1) * block_size).
Expand Down Expand Up @@ -332,7 +356,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
def compute_index(self):
return self.spec.index_map

def get_dma_slice(self, src_shape, src_dtype, grid_indices):
def get_dma_slice(self, src_ty, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
# block looks like this (for array shape (600, 600) and block shape
# (256, 256)):
Expand Down Expand Up @@ -379,10 +403,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
# for the last iteration on that dimension, we will pick the next highest
# tile multiple, i.e. (96, 256).

if (src_shape := getattr(src_ty, "shape", None)) is None:
raise ValueError(f'Type {src_ty} does not have a type.')

if len(src_shape) < 2:
raise NotImplementedError("Must use >1D values.")

tiling = _make_tiling(src_shape, src_dtype)
tiling = _make_tiling(src_shape, src_ty)
block_indices = self.compute_index(*grid_indices)
return tuple(
_make_block_slice(bi, bs, ss, t)
Expand All @@ -403,6 +431,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
"""Returns a new BufferedRefBase with the given block spec."""
raise NotImplementedError()

def _ref_to_value_aval(ref):
  """Return an abstract value describing the contents of `ref`.

  Args:
    ref: The reference to inspect. `state_types.TransformedRef` instances are
      handled separately from everything else.

  Returns:
    For a `TransformedRef`, a new `jax_core.ShapedArray` built from
    `ref.shape` and `ref.dtype`. For any other input, the `inner_aval` of
    `jax.typeof(ref)`.
  """
  return (
      # Built directly from the ref's own shape and dtype rather than via
      # `jax.typeof`.
      jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
      if isinstance(ref, state_types.TransformedRef)
      else jax.typeof(ref).inner_aval
  )


# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
# instead of slot index.
Expand All @@ -413,7 +449,6 @@ class BufferedRef(BufferedRefBase):
Attributes:
spec: pallas blockspec.
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
window_ref: a multiple-buffer to hold the working and dirty buffers used
Expand Down Expand Up @@ -444,7 +479,6 @@ class BufferedRef(BufferedRefBase):
copy.
"""
_spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True))
dtype: Any = dataclasses.field(metadata=dict(static=True))
_buffer_type: BufferType = dataclasses.field(metadata=dict(static=True))
window_ref: ArrayRef | None
accum_ref: ArrayRef | None
Expand Down Expand Up @@ -507,7 +541,7 @@ def buffer_types() -> type[BufferType]:
return BufferType

@classmethod
def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count,
needs_swap_ref=True,
grid_rank=None,
use_lookahead=False,
Expand All @@ -516,7 +550,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
Args:
spec: pallas blockspec.
dtype: dtype for buffers.
dtype_or_type: dtype or aval for buffers. If an aval, the shape is
ignored.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
needs_swap_ref: whether a swap slots tracker needs to be allocated.
Expand All @@ -527,9 +562,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
Returns:
Initialized BufferedRef
"""

# (123, 456) is a dummy shape since we never use ty without
# calling .update(shape=...) first.
ty = (
dtype_or_type
if isinstance(dtype_or_type, jax_core.AbstractValue)
else jax_core.ShapedArray((123, 456), dtype_or_type)
)

block_shape = _get_block_shape(spec)
if buffer_type is BufferType.ACCUMULATOR:
accum_ref = VMEM(block_shape, dtype)
accum_ref = VMEM.from_type(ty.update(shape=block_shape))
else:
accum_ref = None
if source_memory_space == VMEM:
Expand All @@ -541,7 +585,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
f"Cannot hold a non-buffered ref in {spec.memory_space=}")
return cls(
_spec=spec,
dtype=dtype,
_buffer_type=buffer_type,
window_ref=None, # to be bound to existing ref by the pipeline routine
accum_ref=accum_ref,
Expand Down Expand Up @@ -570,11 +613,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
raise ValueError(
"grid_rank must be specified when use_lookahead is True."
)

buffer_ty = ty.update(shape=(buffer_count, *block_shape))
return cls(
_spec=spec,
dtype=dtype,
_buffer_type=buffer_type,
window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype),
window_ref=buffer_memory_space.from_type(buffer_ty),
accum_ref=accum_ref,
copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
Expand All @@ -601,22 +645,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
)

@classmethod
def input(cls, spec, dtype, buffer_count=2, **kwargs):
return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs)
def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
return cls.create(
spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs
)

@classmethod
def output(cls, spec, dtype, buffer_count=2, **kwargs):
return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs)
def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
return cls.create(
spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs
)

@classmethod
def accumulator(cls, spec, dtype, buffer_count=2, **kwargs):
return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count,
**kwargs)
def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
return cls.create(
spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs
)

@classmethod
def input_output(cls, spec, dtype, buffer_count=2, **kwargs):
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count,
**kwargs)
def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
return cls.create(
spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs
)

@property
def block_shape(self):
Expand Down Expand Up @@ -923,7 +973,7 @@ def copy_in(self, src_ref, grid_indices):
if self.swap is not None:
self.swap[0] = True
slot = self.current_copy_in_slot
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
dst_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(src_slice, self.block_shape)
Expand All @@ -944,7 +994,7 @@ def copy_out(self, dst_ref, grid_indices):
if self.swap is not None:
self.swap[0] = True
slot = self.current_copy_out_slot
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
src_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(dst_slice, self.block_shape)
Expand All @@ -962,7 +1012,7 @@ def wait_in(self, src_ref, grid_indices):
if not self.is_buffered: return
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
assert self.sem_recvs is not None
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
dst_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(src_slice, self.block_shape)
Expand All @@ -984,7 +1034,7 @@ def wait_out(self, dst_ref, grid_indices):
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
assert self.sem_sends is not None
wait_slot = self.current_wait_out_slot
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
src_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(dst_slice, self.block_shape)
Expand Down Expand Up @@ -1682,7 +1732,9 @@ def make_input_bref(in_spec, in_ref):
use_lookahead = in_spec.pipeline_mode.use_lookahead
if use_lookahead and grid is None:
raise ValueError("Grid must be specified when using lookahead.")
return BufferedRef.input(in_spec, in_ref.dtype, buffer_count,

in_aval = _ref_to_value_aval(in_ref)
return BufferedRef.input(in_spec, in_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
grid_rank=len(grid),
use_lookahead=use_lookahead,
Expand All @@ -1695,11 +1747,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
if out_spec.pipeline_mode.use_lookahead:
raise ValueError("Output buffering does not support lookahead.")

out_aval = _ref_to_value_aval(out_ref)

if accumulate:
return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count,
return BufferedRef.accumulator(out_spec, out_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
source_memory_space=out_ref.memory_space)
return BufferedRef.output(out_spec, out_ref.dtype, buffer_count,
return BufferedRef.output(out_spec, out_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
source_memory_space=out_ref.memory_space)
out_brefs = jax.tree.map(
Expand Down Expand Up @@ -1817,7 +1871,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
bref = dst
hbm_ref = src
copy_in = True
hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices)
hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices)
bref_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(hbm_slice, bref.block_shape)
Expand Down
Loading
Loading