Commit fa4980c
[Mosaic GPU] Change row-warp assignment logic in matmul example epilogue.
Previously we were assigning rows in a round-robin fashion. Now, contiguous
rows are assigned to the same warp for up to
```
vector_len * lanes_per_warp / min(n_out_tiling) = 4 * 32 / 32 = 4 rows.
```
This could theoretically help with small tile sizes, but in practice it
doesn't seem to make a difference. Benchmarking with parameters `lhs_dtype=jnp.float32`, `rhs_dtype=jnp.float32`, `tile_m=128`, `rhs_transpose=True`, `stages=2`, and varying values for `tile_n`, gives us the following results.
Before:
```
tile_n=32: 94.9 us = 93.4 TFLOPS
tile_n=64: 74.2 us = 119.4 TFLOPS
tile_n=128: 73.1 us = 121.3 TFLOPS
```
After:
```
tile_n=32: 96.1 us = 92.2 TFLOPS
tile_n=64: 71.9 us = 123.1 TFLOPS
tile_n=128: 73.1 us = 121.1 TFLOPS
```
PiperOrigin-RevId: 6383194801 parent cc0a20f commit fa4980c
1 file changed
+27
-12
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
389 | 389 | | |
390 | 390 | | |
391 | 391 | | |
| 392 | + | |
392 | 393 | | |
393 | 394 | | |
394 | 395 | | |
395 | 396 | | |
396 | 397 | | |
397 | 398 | | |
398 | | - | |
399 | | - | |
400 | | - | |
401 | | - | |
402 | | - | |
403 | | - | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
404 | 421 | | |
405 | 422 | | |
406 | 423 | | |
407 | | - | |
408 | | - | |
409 | | - | |
410 | | - | |
411 | | - | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
412 | 427 | | |
413 | 428 | | |
414 | 429 | | |
| |||
418 | 433 | | |
419 | 434 | | |
420 | 435 | | |
421 | | - | |
| 436 | + | |
422 | 437 | | |
423 | 438 | | |
424 | 439 | | |
| |||
0 commit comments