Skip to content

Commit 50bdc72

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Slightly tweaked the error messages in a few places
PiperOrigin-RevId: 842272894
1 parent 2f62fb1 commit 50bdc72

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,8 @@ def untransform_reshape(
670670
self, dtype: jnp.dtype, shape: tuple[int, ...]
671671
) -> tuple[tuple[int, ...], state_types.Transform]:
672672
del dtype
673-
raise NotImplementedError("Reshapes don't commute with transposes.")
673+
# TODO(slebedev): Support this.
674+
raise NotImplementedError("Reshapes don't commute with tiling.")
674675

675676
def untransform_index(
676677
self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...]

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
862862
if max(slice_shape) > 256:
863863
raise ValueError(
864864
"Async copies only support copying <=256 elements along each"
865-
" dimension"
865+
f" dimension, got {tuple(slice_shape)}"
866866
)
867867
if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0:
868868
raise ValueError(
@@ -1019,7 +1019,7 @@ def async_copy(
10191019
raise ValueError(
10201020
"Expected the SMEM reference to have the same shape as the"
10211021
f" transformed slice: {tuple(smem_ref_ty.shape)} !="
1022-
f" {slice_shape[len(squeezed_dims):]}"
1022+
f" {tuple(slice_shape[len(squeezed_dims):])}"
10231023
)
10241024

10251025
if implementation == AsyncCopyImplementation.CP_ASYNC:

jax/experimental/mosaic/gpu/wgmma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def value(self) -> fa.FragmentedArray:
6565
@classmethod
6666
def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
6767
if m % 64 or n % 8:
68-
raise ValueError
68+
raise ValueError("WGMMA requires m and n to be multiples of 64 and 8, "
69+
f"got {m} and {n}")
6970
if is_signed is False: # pylint: disable=g-bool-id-comparison
7071
raise TypeError("PTX does not support unsigned WGMMA accumulators")
7172
f32 = ir.F32Type.get()

0 commit comments

Comments
 (0)