@@ -0,0 +1,283 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch.utils.cpp_extension import load
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

# Build CUDA kernel directly using torch.utils.cpp_extension.load
mxfp8_cuda = None
try:
    # Get the kernel source directory
    KERNEL_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "..",
        "..",
        "..",
        "torchao",
        "csrc",
        "cuda",
        "mx_kernels",
    )
    KERNEL_DIR = os.path.normpath(KERNEL_DIR)

print("Compiling CUDA kernel...")
mxfp8_cuda = load(
name="mxfp8_cuda",
sources=[
os.path.join(KERNEL_DIR, "mxfp8_extension.cpp"),
os.path.join(KERNEL_DIR, "mxfp8_cuda.cu"),
os.path.join(KERNEL_DIR, "mx_block_rearrange_2d_K_groups.cu"),
],
extra_cuda_cflags=[
"-O3",
"--use_fast_math",
"-std=c++17",
"-gencode=arch=compute_100,code=sm_100",
        ],
        extra_cflags=["-O3", "-std=c++17"],
        verbose=True,
    )
    print("✓ CUDA kernel compilation successful!")
except (ImportError, RuntimeError) as e:
print(f"⚠ CUDA kernel not available: {e}")
print("The benchmark will only run 'naive' and 'parallel' Triton versions.\n")

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int]
    num_groups: int
    version: str  # "triton" or one of the "cuda_*" variants


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float
    mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Input scales are grouped along the token dim
    # (total_K here), which contains all the token groups.
    block_size = 32
    input_shapes = [
        (5120, 16384 // block_size),
        (5120, 131072 // block_size),
        (8192, 16384 // block_size),
        (8192, 131072 // block_size),
        (7168, 16384 // block_size),
        (7168, 131072 // block_size),
        (2048, 16384 // block_size),
        (2048, 131072 // block_size),
    ]
    num_groups = [8]
    versions = [
        "triton",
        "cuda_rowmajor",
        "cuda_colmajor",
        "cuda_colmajor_vec",
        "cuda_colmajor_vec_16B",
        "cuda_rowmajor_vec",
        "cuda_rowmajor_128x4_vec",
    ]

    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)
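    # generate_jagged_offs produces cumulative (end) offsets of the jagged groups
    # along the scale-K dim, each a multiple of block_size.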

    # Select which kernel to benchmark based on version
    if version == "triton":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups
        # Triton uses row-major input
        kernel_input = input_tensor
    elif version == "cuda_rowmajor":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor
        # Row-major kernel expects contiguous row-major input
        kernel_input = input_tensor.contiguous()
    elif version == "cuda_colmajor":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor
        # Column-major kernel expects column-major input
        # Column-major: same shape (rows, cols) but stride(0)=1, stride(1)=rows
        kernel_input = input_tensor.T.contiguous().T
    elif version == "cuda_colmajor_vec":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized
        # Vectorized column-major kernel also expects column-major input
        kernel_input = input_tensor.T.contiguous().T
    elif version == "cuda_colmajor_vec_16B":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B
        # 16B vectorized column-major kernel also expects column-major input
        kernel_input = input_tensor.T.contiguous().T
    elif version == "cuda_rowmajor_vec":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor_vectorized
        # Row-major vectorized kernel expects contiguous row-major input
        kernel_input = input_tensor.contiguous()
    elif version == "cuda_rowmajor_128x4_vec":
        if mxfp8_cuda is None:
            raise RuntimeError("CUDA kernel not available")
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec
        # Row-major 128x4 vectorized kernel expects contiguous row-major input
        kernel_input = input_tensor.contiguous()
    else:
        raise ValueError(f"Unknown version: {version}")

    # Run kernel to get output shape
    out_scales = kernel_fn(
        kernel_input,
        input_group_offsets,
    )

    # Benchmark the kernel
    # Note: column-major tensors are not "contiguous" in PyTorch's row-major sense,
    # but they are valid and have the expected strides for the CUDA kernel
    time_us = benchmark_cuda_function_in_microseconds(
        kernel_fn,
        kernel_input,
        input_group_offsets,
    )

    # Calculate memory bandwidth
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = out_scales.numel() * bytes_per_output_el
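    # Bandwidth model: input and output scales are both 1 byte per element (e8m0),
    # read once and written once (the output includes per-group/row padding).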

    mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)

    return ExperimentResult(
        time_us=time_us,
        mem_bw_gbps=mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    # Group experiments by input shape
    shapes_dict = {}
    for exp in experiments:
        shape_key = exp.config.input_shape
        if shape_key not in shapes_dict:
            shapes_dict[shape_key] = {}
        shapes_dict[shape_key][exp.config.version] = exp.result

    headers = [
        "kernel_version",
        "input_shape",
        "time_us",
        "mem_bw_gbps",
        "fastest_version",
        "speedup_vs_triton",
    ]

    rows = []
    for shape, versions in shapes_dict.items():
        # Find fastest version for this shape
        fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]

        # Get triton baseline time for speedup calculation
        triton_time_us = (
            versions.get("triton").time_us if "triton" in versions else None
        )

        # Add rows for each version
        for version, result in versions.items():
            # Calculate speedup vs triton (skipped if the triton baseline is absent)
            speedup_str = ""
            if version != "triton" and triton_time_us is not None:
                speedup = triton_time_us / result.time_us
                speedup_str = f"{speedup:.2f}x"

            rows.append(
                [
                    version,
                    f"({shape[0]}, {shape[1]})",
                    f"{result.time_us:.2f}",
                    round(result.mem_bw_gbps, 3),
                    fastest_version,
                    speedup_str,
                ]
            )

    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions setup.py
@@ -702,6 +702,7 @@ def get_extensions():
    mxfp8_sources = [
        os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
        os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
        os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"),
    ]

    # Only add the extension if the source files exist AND we are building for sm100
59 changes: 59 additions & 0 deletions test/prototype/moe_training/test_kernels.py
@@ -354,3 +354,62 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
    # Check quantized values
    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"


@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
def test_cuda_mx_block_rearrange_2d_K_groups(
    m: int,
    total_k: int,
    n_groups: int,
):
    """
    Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference.
    This kernel rearranges E8M0 scales to block-scaled swizzle format for cuBLAS Tmem.
    """
    from torchao.prototype import mxfp8_cuda

    device = "cuda"
    block_size = 32
    input_data = torch.randn(m, total_k, device=device)

    e8m0_scales, _ = to_mx(
        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
    )
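    # to_mx returns (e8m0 scales, quantized data); with block_size=32 along the last
    # dim, the scales are expected to have shape (m, total_k // block_size).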

    # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
    input_group_offsets = generate_jagged_offs(
        n_groups, total_k, multiple_of=block_size, device=device
    )
    scale_group_offsets = input_group_offsets // block_size

    # Triton reference implementation
    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
        e8m0_scales,
        scale_group_offsets,
    )

    # CUDA kernel implementation
    cuda_out_scales = mxfp8_cuda.mx_block_rearrange_2d_K_groups(
        e8m0_scales.view(torch.uint8),
        scale_group_offsets,
    )

    # Check that outputs match
    assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), (
        "CUDA and Triton blocked scales not equal"
    )

    # Verify output shape
    expected_rows = ((m + 127) // 128) * 128  # Padded to multiple of 128
    expected_cols = (
        e8m0_scales.size(1) + n_groups * 4
    )  # Original cols + padding per group
    assert cuda_out_scales.shape == (expected_rows, expected_cols), (
        f"Output shape mismatch: expected {(expected_rows, expected_cols)}, got {cuda_out_scales.shape}"
    )