Problem statement
The BLAS level-3 `gemm` kernel has no Triton implementation in `kernel_course.triton_ops`. The README BLAS table shows an empty Triton column for `gemm`, even though GEMM is the canonical example for high-performance GPU kernels.
Without a Triton `gemm` kernel:
- users cannot study or experiment with GEMM implementations in Triton within this project,
- there is no Triton baseline for performance comparison against PyTorch and CuTe GEMM,
- the Triton backend lacks a flagship example that ties together tiling, memory layouts, and parallelism.
Proposed solution
Add a Triton implementation of `gemm` under `kernel_course.triton_ops` that matches the Python/PyTorch semantics and showcases best practices for GEMM-style kernels.
Concretely:
- Introduce `kernel_course/triton_ops/gemm.py` defining a Triton JIT GEMM kernel and a Python wrapper (see the sketch after this list).
- Implement tiled matrix multiplication with configurable block sizes and good memory access patterns.
- Ensure numerical equivalence with the Python reference across supported dtypes.
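A minimal sketch of what the kernel and wrapper could look like, following the standard Triton tiled-matmul pattern (one `C` tile per program, masked loads at the edges, float32 accumulator). The names `gemm_kernel`/`gemm` and the block sizes are illustrative assumptions, not a settled API:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program instance owns one BLOCK_M x BLOCK_N tile of C.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulate in float32 even for float16 inputs, for numerical stability.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks bound-check the M/N edges and the final partial K tile.
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] + k < K)
        b_mask = (offs_k[:, None] + k < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)


def gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """C = A @ B for 2-D CUDA tensors with matching inner dimension."""
    assert a.ndim == 2 and b.ndim == 2 and a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # Fixed tile sizes for clarity; triton.autotune could pick these instead.
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemm_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c
```

Passing strides explicitly keeps the kernel layout-agnostic, and a plain 2-D grid keeps the example readable; the grouped-ordering trick from the Triton matmul tutorial could be layered on later for L2 reuse.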
Alternatives considered
Relying solely on PyTorch or CuTe for GEMM would:
- miss a key educational opportunity to present GEMM in Triton,
- limit performance and implementation comparisons across backends,
- leave the Triton column incomplete in the README BLAS table.
Implementation details
- Add `kernel_course/triton_ops/gemm.py` containing:
  - a Triton kernel with parameters for tiles of `A`, `B`, and `C`,
  - a wrapper to configure grid/block sizes and launch the kernel.
- Handle arbitrary matrix shapes with proper bounds checks.
- Consider multiple dtypes (e.g. `float16`, `float32`) and accumulation strategies.
- Integrate with future tests and benchmarks (a test sketch follows this list).
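A rough idea of the equivalence check, assuming the wrapper lands at `kernel_course.triton_ops.gemm.gemm` (a hypothetical import path) and comparing against `torch.matmul` with dtype-dependent tolerances:

```python
import torch
from kernel_course.triton_ops.gemm import gemm  # hypothetical module path


def test_gemm_matches_torch():
    torch.manual_seed(0)
    # Odd shapes deliberately avoid multiples of the block sizes,
    # so the bounds-check masks are actually exercised.
    for dtype, tol in [(torch.float16, 1e-2), (torch.float32, 1e-4)]:
        a = torch.randn(513, 257, device="cuda", dtype=dtype)
        b = torch.randn(257, 129, device="cuda", dtype=dtype)
        torch.testing.assert_close(gemm(a, b), a @ b, atol=tol, rtol=tol)
```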
Use case
The Triton `gemm` kernel will:
- act as a central example of high-performance GPU kernel development in this project,
- enable detailed performance comparisons with PyTorch and CuTe GEMM implementations,
- serve as a building block for Transformer modules and other complex workloads.
Related work
- Existing Triton kernels: `triton_ops.copy`, `triton_ops.swap`.
- Triton GEMM tutorials and reference implementations.
Additional context
This issue is a key part of extending the Triton backend to cover the BLAS level-3 operations listed in the README, starting with `gemm`.