2323from typing import Any , Union
2424
2525import jax
26+ from jax import core as jax_core
2627from jax import lax
2728from jax import tree_util
2829from jax ._src import util as jax_util
2930from jax ._src .pallas import core as pallas_core
3031from jax ._src .pallas import primitives as primitives
3132from jax ._src .pallas .mosaic import core as tpu_core
3233from jax ._src .pallas .mosaic import helpers as tpu_helpers
33- from jax ._src .pallas .mosaic import tpu_info
3434from jax ._src .pallas .mosaic import primitives as tpu_primitives
35+ from jax ._src .pallas .mosaic import tpu_info
36+ from jax ._src .state import types as state_types
3537from jax .experimental import pallas as pl
3638import jax .numpy as jnp
37- import numpy as np
3839
3940
4041SMEM = tpu_core .MemorySpace .SMEM
@@ -79,17 +80,32 @@ def add_leaves(i, x):
def _get_tpu_generation() -> int:
  """Return the `generation` field of `tpu_info.get_tpu_info()`."""
  info = tpu_info.get_tpu_info()
  return info.generation
8182
82- def _make_tiling (shape : tuple [int , ...], dtype : np .dtype ) -> tuple [int , ...]:
83- # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
84- # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
85- # (2, 3, 128, 128) -> (1, 1, 8, 128).
83+
84+ def _make_tiling (
85+ shape : tuple [int , ...], ty : jax_core .AbstractValue
86+ ) -> tuple [int | None , ...]:
87+ """Compute a tiling for the given shape and type.
88+
89+ For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
90+ and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
91+ (2, 3, 128, 128) -> (1, 1, 8, 128).
92+
93+ Types are not required to have a dtype, so for such types we return None for
94+ all dimensions because their tiling is unknown.
95+ """
96+
8697 if len (shape ) < 2 :
8798 raise ValueError (f"Shape must have at least 2 dimensions: { shape = } " )
99+
100+ if not hasattr (ty , 'dtype' ):
101+ return (None ,) * len (shape )
102+
88103 leading_dims , final_dims = shape [:- 2 ], shape [- 2 :]
89104 # We want to find the minimum power of 2 that fits the second-minor dimension
90105 # of shape, with maximum value 8.
91106 second_minor , _ = final_dims
92- packing = 4 // dtype .itemsize
107+
108+ packing = 4 // ty .dtype .itemsize
93109 max_tiling = _TILING [0 ]
94110 second_minor_tiling = (1 + int (_get_tpu_generation () < 4 )) * packing
95111 while second_minor_tiling < min (second_minor , max_tiling ):
@@ -114,17 +130,23 @@ def _make_block_ds(
114130 assert isinstance (out , pl .Slice )
115131 return out
116132
117- def _create_blocked_slice (block_index : jax .Array | int ,
118- block_size : int ,
119- dim_size : int ,
120- tiling : int ):
133+
134+ def _create_blocked_slice (
135+ block_index : jax .Array | int ,
136+ block_size : int ,
137+ dim_size : int ,
138+ tiling : int | None ,
139+ ):
121140 block_start = block_size * block_index
122141 if (dim_rem := dim_size % block_size ) == 0 :
123142 return pl .ds (block_start , block_size )
143+ if tiling is None :
144+ raise ValueError ("If tiling is None, block_size must divide dim_size." )
124145 if block_size % tiling != 0 :
125146 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
126147 num_blocks = pl .cdiv (dim_size , block_size )
127148 is_last = block_index == num_blocks - 1
149+
128150 rounded_size = jnp .where (
129151 is_last ,
130152 _round_up_to_nearest_multiple (dim_rem % block_size , tiling ),
@@ -133,20 +155,26 @@ def _create_blocked_slice(block_index: jax.Array | int,
133155 rounded_size = pl .multiple_of (rounded_size , tiling )
134156 return pl .ds (block_index * block_size , rounded_size )
135157
158+
136159def _create_bounded_slice (slice_start : jax .Array | int ,
137160 slice_size : jax .Array | int ,
138161 block_size : int ,
139162 dim_size : int ,
140- tiling : int ):
141- if block_size % tiling != 0 :
163+ tiling : int | None ):
164+ if tiling is not None and block_size % tiling != 0 :
142165 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
166+
143167 # We assume by construction that slice_size <= block_size. We also assume
144168 # that the slice_start is already aligned to the tiling.
145169
170+ if tiling is None :
171+ return pl .ds (slice_start , slice_size )
172+
146173 # If we are out of bound, we need to round the slice size down to the nearest
147174 # multiple of the tiling.
148175 is_oob = slice_start + slice_size > dim_size
149176 remaining = dim_size - slice_start
177+
150178 rounded_size = jnp .where (
151179 is_oob ,
152180 _round_up_to_nearest_multiple (remaining , tiling ),
@@ -157,7 +185,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,
157185
158186def _make_block_slice (
159187 block_index : jax .Array , block_size : pl .BlockDim | int | None , size : int ,
160- tiling : int
188+ tiling : int | None
161189) -> pl .Slice | slice | int | jax .Array :
162190 # Computes a slice given a block index and block size. In the default case,
163191 # we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +360,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
332360 def compute_index (self ):
333361 return self .spec .index_map
334362
335- def get_dma_slice (self , src_shape , src_dtype , grid_indices ):
363+ def get_dma_slice (self , src_ty , grid_indices ):
336364 # We need to handle blocks that might go OOB in the src array. An in bounds
337365 # block looks like this (for array shape (600, 600) and block shape
338366 # (256, 256)):
@@ -379,10 +407,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
379407 # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
380408 # for the last iteration on that dimension, we will pick the next highest
381409 # tile multiple, i.e. (96, 256).
410+
411+ if (src_shape := getattr (src_ty , "shape" , None )) is None :
412+ raise ValueError (f'Type { src_ty } does not have a type.' )
413+
382414 if len (src_shape ) < 2 :
383415 raise NotImplementedError ("Must use >1D values." )
384416
385- tiling = _make_tiling (src_shape , src_dtype )
417+ tiling = _make_tiling (src_shape , src_ty )
386418 block_indices = self .compute_index (* grid_indices )
387419 return tuple (
388420 _make_block_slice (bi , bs , ss , t )
@@ -403,6 +435,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
403435 """Returns a new BufferedRefBase with the given block spec."""
404436 raise NotImplementedError ()
405437
def _ref_to_value_aval(ref):
  """Return the abstract value of the array that `ref` refers to.

  For a `TransformedRef`, a new `ShapedArray` is built from the ref's `shape`
  and `dtype` attributes. For any other ref, the `inner_aval` of its
  `jax.typeof` type is returned unchanged.
  """
  if not isinstance(ref, state_types.TransformedRef):
    return jax.typeof(ref).inner_aval
  return jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
445+
406446
407447# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
408448# instead of slot index.
@@ -413,7 +453,6 @@ class BufferedRef(BufferedRefBase):
413453
414454 Attributes:
415455 spec: pallas blockspec.
416- dtype: dtype for buffers.
417456 buffer_type: enum indicating whether this is an input, output, or in/out
418457 accumulator buffered reference.
419458 window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +483,6 @@ class BufferedRef(BufferedRefBase):
444483 copy.
445484 """
446485 _spec : pl .BlockSpec = dataclasses .field (metadata = dict (static = True ))
447- dtype : Any = dataclasses .field (metadata = dict (static = True ))
448486 _buffer_type : BufferType = dataclasses .field (metadata = dict (static = True ))
449487 window_ref : ArrayRef | None
450488 accum_ref : ArrayRef | None
@@ -507,7 +545,7 @@ def buffer_types() -> type[BufferType]:
507545 return BufferType
508546
509547 @classmethod
510- def create (cls , spec : pl .BlockSpec , dtype , buffer_type , buffer_count ,
548+ def create (cls , spec : pl .BlockSpec , dtype_or_type , buffer_type , buffer_count ,
511549 needs_swap_ref = True ,
512550 grid_rank = None ,
513551 use_lookahead = False ,
@@ -516,7 +554,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
516554
517555 Args:
518556 spec: pallas blockspec.
519- dtype: dtype for buffers.
557+ dtype_or_type: dtype or aval for buffers. If an aval, the shape is
558+ ignored.
520559 buffer_type: enum indicating whether this is an input, output, or in/out
521560 accumulator buffered reference.
522561 needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +566,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
527566 Returns:
528567 Initialized BufferedRef
529568 """
569+
570+ # (123, 456) is a dummy shape since we never use ty without
571+ # calling .update(shape=...) first.
572+ ty = (
573+ dtype_or_type
574+ if isinstance (dtype_or_type , jax_core .AbstractValue )
575+ else jax_core .ShapedArray ((123 , 456 ), dtype_or_type )
576+ )
577+
530578 block_shape = _get_block_shape (spec )
531579 if buffer_type is BufferType .ACCUMULATOR :
532- accum_ref = VMEM ( block_shape , dtype )
580+ accum_ref = VMEM . from_type ( ty . update ( shape = block_shape ) )
533581 else :
534582 accum_ref = None
535583 if source_memory_space == VMEM :
@@ -541,7 +589,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
541589 f"Cannot hold a non-buffered ref in { spec .memory_space = } " )
542590 return cls (
543591 _spec = spec ,
544- dtype = dtype ,
545592 _buffer_type = buffer_type ,
546593 window_ref = None , # to be bound to existing ref by the pipeline routine
547594 accum_ref = accum_ref ,
@@ -570,11 +617,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
570617 raise ValueError (
571618 "grid_rank must be specified when use_lookahead is True."
572619 )
620+
621+ buffer_ty = ty .update (shape = (buffer_count , * block_shape ))
573622 return cls (
574623 _spec = spec ,
575- dtype = dtype ,
576624 _buffer_type = buffer_type ,
577- window_ref = buffer_memory_space (( buffer_count ,) + block_shape , dtype ),
625+ window_ref = buffer_memory_space . from_type ( buffer_ty ),
578626 accum_ref = accum_ref ,
579627 copy_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
580628 wait_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
@@ -601,22 +649,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
601649 )
602650
@classmethod
def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
  """Create a buffered ref of type `BufferType.INPUT`.

  Thin wrapper around `create`: `spec`, `dtype_or_type` and `buffer_count`
  are passed positionally and any extra keyword arguments are forwarded.
  """
  kind = BufferType.INPUT
  return cls.create(spec, dtype_or_type, kind, buffer_count, **kwargs)
606656
@classmethod
def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
  """Create a buffered ref of type `BufferType.OUTPUT`.

  Thin wrapper around `create`: `spec`, `dtype_or_type` and `buffer_count`
  are passed positionally and any extra keyword arguments are forwarded.
  """
  kind = BufferType.OUTPUT
  return cls.create(spec, dtype_or_type, kind, buffer_count, **kwargs)
610662
@classmethod
def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
  """Create a buffered ref of type `BufferType.ACCUMULATOR`.

  Thin wrapper around `create`: `spec`, `dtype_or_type` and `buffer_count`
  are passed positionally and any extra keyword arguments are forwarded.
  """
  kind = BufferType.ACCUMULATOR
  return cls.create(spec, dtype_or_type, kind, buffer_count, **kwargs)
615668
@classmethod
def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
  """Create a buffered ref of type `BufferType.INPUT_OUTPUT`.

  Thin wrapper around `create`: `spec`, `dtype_or_type` and `buffer_count`
  are passed positionally and any extra keyword arguments are forwarded.
  """
  kind = BufferType.INPUT_OUTPUT
  return cls.create(spec, dtype_or_type, kind, buffer_count, **kwargs)
620674
621675 @property
622676 def block_shape (self ):
@@ -923,7 +977,7 @@ def copy_in(self, src_ref, grid_indices):
923977 if self .swap is not None :
924978 self .swap [0 ] = True
925979 slot = self .current_copy_in_slot
926- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
980+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
927981 dst_slice = tuple (
928982 pl .ds (0 , s .size )
929983 for s , bd in zip (src_slice , self .block_shape )
@@ -944,7 +998,7 @@ def copy_out(self, dst_ref, grid_indices):
944998 if self .swap is not None :
945999 self .swap [0 ] = True
9461000 slot = self .current_copy_out_slot
947- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
1001+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
9481002 src_slice = tuple (
9491003 pl .ds (0 , s .size )
9501004 for s , bd in zip (dst_slice , self .block_shape )
@@ -962,7 +1016,7 @@ def wait_in(self, src_ref, grid_indices):
9621016 if not self .is_buffered : return
9631017 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9641018 assert self .sem_recvs is not None
965- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
1019+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
9661020 dst_slice = tuple (
9671021 pl .ds (0 , s .size )
9681022 for s , bd in zip (src_slice , self .block_shape )
@@ -984,7 +1038,7 @@ def wait_out(self, dst_ref, grid_indices):
9841038 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9851039 assert self .sem_sends is not None
9861040 wait_slot = self .current_wait_out_slot
987- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
1041+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
9881042 src_slice = tuple (
9891043 pl .ds (0 , s .size )
9901044 for s , bd in zip (dst_slice , self .block_shape )
@@ -1682,7 +1736,9 @@ def make_input_bref(in_spec, in_ref):
16821736 use_lookahead = in_spec .pipeline_mode .use_lookahead
16831737 if use_lookahead and grid is None :
16841738 raise ValueError ("Grid must be specified when using lookahead." )
1685- return BufferedRef .input (in_spec , in_ref .dtype , buffer_count ,
1739+
1740+ in_aval = _ref_to_value_aval (in_ref )
1741+ return BufferedRef .input (in_spec , in_aval , buffer_count ,
16861742 needs_swap_ref = needs_swap_ref ,
16871743 grid_rank = len (grid ),
16881744 use_lookahead = use_lookahead ,
@@ -1695,11 +1751,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
16951751 if out_spec .pipeline_mode .use_lookahead :
16961752 raise ValueError ("Output buffering does not support lookahead." )
16971753
1754+ out_aval = _ref_to_value_aval (out_ref )
1755+
16981756 if accumulate :
1699- return BufferedRef .accumulator (out_spec , out_ref . dtype , buffer_count ,
1757+ return BufferedRef .accumulator (out_spec , out_aval , buffer_count ,
17001758 needs_swap_ref = needs_swap_ref ,
17011759 source_memory_space = out_ref .memory_space )
1702- return BufferedRef .output (out_spec , out_ref . dtype , buffer_count ,
1760+ return BufferedRef .output (out_spec , out_aval , buffer_count ,
17031761 needs_swap_ref = needs_swap_ref ,
17041762 source_memory_space = out_ref .memory_space )
17051763 out_brefs = jax .tree .map (
@@ -1817,7 +1875,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
18171875 bref = dst
18181876 hbm_ref = src
18191877 copy_in = True
1820- hbm_slice = bref .get_dma_slice (hbm_ref . shape , hbm_ref . dtype , indices )
1878+ hbm_slice = bref .get_dma_slice (_ref_to_value_aval ( hbm_ref ) , indices )
18211879 bref_slice = tuple (
18221880 pl .ds (0 , s .size )
18231881 for s , bd in zip (hbm_slice , bref .block_shape )
0 commit comments