[FEATURE REQUEST] gemm Triton kernel implementation #32

@LoserCheems

Description

Problem statement

The BLAS level-3 gemm kernel has no Triton implementation in kernel_course.triton_ops. The README BLAS table shows an empty Triton column for gemm, even though GEMM is the canonical example for high-performance GPU kernels.

Without a Triton gemm kernel:

  • users cannot study or experiment with GEMM implementations in Triton within this project,
  • there is no Triton baseline for performance comparison against PyTorch and CuTe GEMM,
  • the Triton backend lacks a flagship example that ties together tiling, memory layouts, and parallelism.

Proposed solution

Add a Triton implementation of gemm under kernel_course.triton_ops that matches the Python/PyTorch semantics and showcases best practices for GEMM-style kernels.

Concretely:

  • Introduce kernel_course/triton_ops/gemm.py defining a Triton JIT GEMM kernel and Python wrapper.
  • Implement tiled matrix multiplication with configurable block sizes and coalesced, cache-friendly memory access patterns.
  • Ensure numerical equivalence with the Python reference across supported dtypes.
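The tiling scheme such a kernel would follow can be sketched in plain NumPy (block sizes, the function name, and the masking strategy here are illustrative, not a committed API): each program instance would own one BLOCK_M × BLOCK_N tile of C and accumulate partial products over BLOCK_K slices of A and B, with out-of-bounds elements masked so arbitrary shapes work.

```python
import numpy as np

def tiled_gemm(a, b, block_m=32, block_n=32, block_k=16):
    """Reference tiled matmul mirroring the loop structure a Triton
    GEMM kernel would use (one tile of C per program instance)."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    c = np.zeros((m, n), dtype=np.float32)
    # Each (i0, j0) iteration corresponds to one Triton program instance.
    for i0 in range(0, m, block_m):
        for j0 in range(0, n, block_n):
            acc = np.zeros((min(block_m, m - i0), min(block_n, n - j0)),
                           dtype=np.float32)  # float32 accumulator
            # Inner loop over K, one BLOCK_K slice at a time; NumPy
            # slicing past the end plays the role of a masked load.
            for k0 in range(0, k, block_k):
                a_tile = a[i0:i0 + block_m, k0:k0 + block_k]
                b_tile = b[k0:k0 + block_k, j0:j0 + block_n]
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            c[i0:i0 + block_m, j0:j0 + block_n] = acc
    return c
```

Checking this sketch against `a @ b` for odd, non-multiple-of-block shapes is the same numerical-equivalence test the Triton kernel would need to pass.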

Alternatives considered

Relying solely on PyTorch or CuTe for GEMM would:

  • miss a key educational opportunity to present GEMM in Triton,
  • limit performance and implementation comparisons across backends,
  • leave the Triton column incomplete in the README BLAS table.

Implementation details

  • Add kernel_course/triton_ops/gemm.py containing:
    • a Triton kernel with parameters for tiles of A, B, and C,
    • a wrapper to configure grid/block sizes and launch the kernel.
  • Handle arbitrary matrix shapes with proper bounds checks (masked loads and stores at tile edges).
  • Consider multiple dtypes and accumulation strategies (e.g. float16 inputs with float32 accumulation).
  • Integrate with future tests and benchmarks.
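For the wrapper, the launch grid would follow the usual Triton convention of one program instance per output tile, sized with ceiling division. A minimal sketch of that grid computation (helper names here are hypothetical, not part of any existing `kernel_course` API):

```python
def cdiv(x, y):
    """Ceiling division, the standard idiom for sizing Triton launch grids."""
    return (x + y - 1) // y

def gemm_grid(m, n, block_m=64, block_n=64):
    """Hypothetical helper: the 2-D launch grid a gemm wrapper would use,
    one program instance per BLOCK_M x BLOCK_N tile of C."""
    return (cdiv(m, block_m), cdiv(n, block_n))
```

The wrapper would then pass this grid (or an equivalent grid lambda) when launching the JIT kernel, alongside the strides and block-size constants.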

Use case

The Triton gemm kernel will:

  • act as a central example of high-performance GPU kernel development in this project,
  • enable detailed performance comparisons with PyTorch and CuTe GEMM implementations,
  • serve as a building block for Transformer modules and other complex workloads.

Related work

  • Existing Triton kernels: triton_ops.copy, triton_ops.swap.
  • Triton GEMM tutorials and reference implementations.

Additional context

This issue is a key part of extending the Triton backend to cover BLAS level-3 operations listed in the README, starting with gemm.
