
Commit fa4980c

bchetioui authored and jax authors committed
[Mosaic GPU] Change row-warp assignment logic in matmul example epilogue.
Previously we were assigning rows in a round-robin fashion. Now, contiguous rows are assigned to the same warp, for up to `vector_len * lanes_per_warp / min(n_out_tiling) = 4 * 32 / 32 = 4` rows.

This could theoretically help with small tile sizes, but in practice it doesn't seem to make a difference. Benchmarking with `lhs_dtype=jnp.float32`, `rhs_dtype=jnp.float32`, `tile_m=128`, `rhs_transpose=True`, `stages=2`, and varying values of `tile_n` gives the following results.

Before:

```
tile_n=32:  94.9 us = 93.4 TFLOPS
tile_n=64:  74.2 us = 119.4 TFLOPS
tile_n=128: 73.1 us = 121.3 TFLOPS
```

After:

```
tile_n=32:  96.1 us = 92.2 TFLOPS
tile_n=64:  71.9 us = 123.1 TFLOPS
tile_n=128: 73.1 us = 121.1 TFLOPS
```

PiperOrigin-RevId: 638319480
1 parent cc0a20f commit fa4980c
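For intuition, here is a minimal pure-Python sketch (not part of the commit) of the lane-to-row/column mapping that the new epilogue loops in the diff below implement. The constants and the `//`/`%` arithmetic mirror the `arith` ops in the change; the function name `lane_assignment` is purely illustrative.

```python
# Illustrative pure-Python model of the new lane -> (row, column) mapping.
# Constants match the kernel: 4 warps per warpgroup, 32 lanes per warp,
# and 4 f32 values (16 bytes) per vectorized store.
WARPS_PER_WARPGROUP = 4
LANES_PER_WARP = 32
VECTOR_LEN = 4


def lane_assignment(tile_m, tile_n):
  """Yields (warp_id, lane_id, row, first_col) for every store a lane issues."""
  num_vectors_per_row = tile_n // VECTOR_LEN
  # Process several rows at once if a single row cannot occupy a full warp.
  if tile_n < LANES_PER_WARP * VECTOR_LEN:
    num_rows_per_warp = min(LANES_PER_WARP * VECTOR_LEN // tile_n,
                            tile_m // WARPS_PER_WARPGROUP)
  else:
    num_rows_per_warp = 1
  lanes_per_row = LANES_PER_WARP // num_rows_per_warp
  for warp_id in range(WARPS_PER_WARPGROUP):
    for lane_id in range(LANES_PER_WARP):
      lane_row_offset = lane_id // lanes_per_row
      lane_col_offset = lane_id % lanes_per_row
      # Outer scf.for: contiguous blocks of num_rows_per_warp rows per warp.
      for start_row in range(warp_id * num_rows_per_warp, tile_m,
                             WARPS_PER_WARPGROUP * num_rows_per_warp):
        row = start_row + lane_row_offset
        # Inner scf.for: each lane strides over the row's 4-wide vectors.
        for vector_idx in range(lane_col_offset, num_vectors_per_row,
                                lanes_per_row):
          yield warp_id, lane_id, row, vector_idx * VECTOR_LEN
```

With `tile_m=128` and `tile_n=32`, this gives each warp 4 contiguous rows per outer-loop step (warp 0 gets rows 0-3, then 16-19, and so on; warp 1 gets rows 4-7), instead of the previous round-robin assignment of one row per warp.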


jax/experimental/mosaic/gpu/examples/matmul.py

Lines changed: 27 additions & 12 deletions
```diff
@@ -389,26 +389,41 @@ def stage_loop_body(ki, accs):
     # TODO(apaszke): Make this into a proper copy function.
     warps_per_warpgroup = 4
     lanes_per_warp = 32
+    m_out_tiling = out_tiling[-2]
     n_out_tiling = out_tiling[-1]
     tidx = gpu.thread_id(gpu.Dimension.x)
     warp_id = arith.divui(tidx, c(lanes_per_warp))
     lane_id = arith.remui(tidx, c(lanes_per_warp))
     # We store 4 f32 numbers for a block of 16B.
     vector_len = 4
-    num_vectors = safe_div(tile_n, vector_len)
-    for_op = scf.ForOp(warp_id, c(tile_m), c(warps_per_warpgroup))
-    with ir.InsertionPoint(for_op.body):
-      nested_for_op = scf.ForOp(lane_id, c(num_vectors), c(lanes_per_warp))
-      with ir.InsertionPoint(nested_for_op.body):
-        vector_idx = nested_for_op.induction_variable
+    num_vectors_per_row = safe_div(tile_n, vector_len)
+    # Process several rows at once if it is necessary to fully exploit each
+    # warp.
+    if tile_n < lanes_per_warp * vector_len:
+      num_rows_per_warp = min(
+          safe_div(lanes_per_warp * vector_len, tile_n),
+          safe_div(tile_m, warps_per_warpgroup))
+    else:
+      num_rows_per_warp = 1
+    lanes_per_row = safe_div(lanes_per_warp, num_rows_per_warp)
+    lane_row_offset = arith.divui(lane_id, c(lanes_per_row))
+    lane_col_offset = arith.remui(lane_id, c(lanes_per_row))
+    warp_for_op = scf.ForOp(arith.muli(warp_id, c(num_rows_per_warp)),
+                            c(tile_m),
+                            c(warps_per_warpgroup * num_rows_per_warp))
+    with ir.InsertionPoint(warp_for_op.body):
+      start_row = warp_for_op.induction_variable
+      m_row_idx = arith.addi(start_row, lane_row_offset)
+      vector_for_op = scf.ForOp(lane_col_offset, c(num_vectors_per_row),
+                                c(lanes_per_row))
+      with ir.InsertionPoint(vector_for_op.body):
+        vector_idx = vector_for_op.induction_variable
         n_store = arith.muli(vector_idx, c(vector_len))
         col_group = arith.divui(n_store, c(n_out_tiling))
         n_load = arith.remui(n_store, c(n_out_tiling))
-
-        m_smem = for_op.induction_variable
-        m_within_tile = arith.remui(m_smem, c(64))
-        m_tile = arith.divui(m_smem, c(64))
-        swizzle_source = arith.shli(arith.remui(m_smem, c(8)), c(2))
+        m_within_tile = arith.remui(m_row_idx, c(m_out_tiling))
+        m_tile = arith.divui(m_row_idx, c(m_out_tiling))
+        swizzle_source = arith.shli(arith.remui(m_row_idx, c(8)), c(2))
         n_acc = arith.xori(n_load, swizzle_source)
         acc_part = vector.load(
             ir.VectorType.get((vector_len,), f32),
@@ -418,7 +433,7 @@ def stage_loop_body(ki, accs):
         vector.store(
             acc_part,
             c_device,
-            [arith.addi(m_start, m_smem), arith.addi(n_start, n_store)],
+            [arith.addi(m_start, m_row_idx), arith.addi(n_start, n_store)],
         )
         scf.yield_([])
       scf.yield_([])
```
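Continuing the pure-Python sketch above (again, not part of the commit), a quick coverage check confirms that the reworked loop bounds and strides still visit every (row, 4-wide column vector) of the output tile exactly once for the benchmarked tile sizes. It reuses the `lane_assignment` function and `VECTOR_LEN` constant from the earlier block.

```python
from collections import Counter


def check_full_coverage(tile_m, tile_n):
  """Asserts every (row, column-vector) pair of the tile is stored exactly once."""
  counts = Counter(
      (row, col) for _, _, row, col in lane_assignment(tile_m, tile_n))
  expected = {(r, c)
              for r in range(tile_m)
              for c in range(0, tile_n, VECTOR_LEN)}
  assert set(counts) == expected
  assert set(counts.values()) == {1}


# The configurations benchmarked in the commit message.
for tile_n in (32, 64, 128):
  check_full_coverage(tile_m=128, tile_n=tile_n)
```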
