diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py
new file mode 100644
index 0000000000..ac39cce8dd
--- /dev/null
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py
@@ -0,0 +1,283 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import os
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from torch.utils.cpp_extension import load
+from tqdm import tqdm
+
+from benchmarks.utils import benchmark_cuda_function_in_microseconds
+from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    triton_mx_block_rearrange_2d_K_groups,
+)
+from torchao.prototype.moe_training.utils import generate_jagged_offs
+
+# Build CUDA kernel directly using torch.utils.cpp_extension.load
+mxfp8_cuda = None
+try:
+    # Get the kernel source directory
+    KERNEL_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)),
+        "..",
+        "..",
+        "..",
+        "..",
+        "torchao",
+        "csrc",
+        "cuda",
+        "mx_kernels",
+    )
+    KERNEL_DIR = os.path.normpath(KERNEL_DIR)
+
+    print("Compiling CUDA kernel...")
+    mxfp8_cuda = load(
+        name="mxfp8_cuda",
+        sources=[
+            os.path.join(KERNEL_DIR, "mxfp8_extension.cpp"),
+            os.path.join(KERNEL_DIR, "mxfp8_cuda.cu"),
+            os.path.join(KERNEL_DIR, "mx_block_rearrange_2d_K_groups.cu"),
+        ],
+        extra_cuda_cflags=[
+            "-O3",
+            "--use_fast_math",
+            "-std=c++17",
+            "-gencode=arch=compute_100,code=sm_100",
+        ],
+        extra_cflags=["-O3", "-std=c++17"],
+        verbose=True,
+    )
+    print("✓ CUDA kernel compilation successful!")
+except (ImportError, RuntimeError) as e:
+    print(f"⚠ CUDA kernel not available: {e}")
+    print("The benchmark will only run the 'triton' version.\n")
+
+device = torch.device("cuda")
+
+# Needed since changing args to the function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    input_shape: tuple[int, int]
+    num_groups: int
+    version: str  # "triton" or one of the CUDA kernel variants
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    time_us: float
+    mem_bw_gbps: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_K dim, which contains all the token groups.
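+    # Worked example of one config (illustrative numbers, assuming
+    # generate_jagged_offs yields ascending group-end offsets, which is how the
+    # kernels below consume them): for input_shape (5120, 131072 // 32) with
+    # num_groups=8, the uint8 scales matrix is (5120, 4096) and the 8 jagged
+    # group-end offsets partition the 4096 scale columns, each a multiple of 32.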
+ block_size = 32 + input_shapes = [ + (5120, 16384 // block_size), + (5120, 131072 // block_size), + (8192, 16384 // block_size), + (8192, 131072 // block_size), + (7168, 16384 // block_size), + (7168, 131072 // block_size), + (2048, 16384 // block_size), + (2048, 131072 // block_size), + ] + num_groups = [8] + versions = [ + "triton", + "cuda_rowmajor", + "cuda_colmajor", + "cuda_colmajor_vec", + "cuda_colmajor_vec_16B", + "cuda_rowmajor_vec", + "cuda_rowmajor_128x4_vec", + ] + + configs = [] + for shape, groups, version in itertools.product( + input_shapes, + num_groups, + versions, + ): + configs.append( + ExperimentConfig( + input_shape=shape, + num_groups=groups, + version=version, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + input_shape, num_groups, version = ( + config.input_shape, + config.num_groups, + config.version, + ) + input_tensor = torch.randint( + low=0, + high=256, + size=input_shape, + dtype=torch.uint8, + device=device, + ) + + M, Kg = input_shape + block_size = 32 + input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size) + + # Select which kernel to benchmark based on version + if version == "triton": + kernel_fn = triton_mx_block_rearrange_2d_K_groups + # Triton uses row-major input + kernel_input = input_tensor + elif version == "cuda_rowmajor": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor + # Row-major kernel expects contiguous row-major input + kernel_input = input_tensor.contiguous() + elif version == "cuda_colmajor": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor + # Column-major kernel expects column-major input + # Column-major: same shape (rows, cols) but stride(0)=1, stride(1)=rows + kernel_input = input_tensor.T.contiguous().T + elif version == "cuda_colmajor_vec": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized + # Vectorized column-major kernel also expects column-major input + kernel_input = input_tensor.T.contiguous().T + elif version == "cuda_colmajor_vec_16B": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B + # 16B vectorized column-major kernel also expects column-major input + kernel_input = input_tensor.T.contiguous().T + elif version == "cuda_rowmajor_vec": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor_vectorized + # Row-major vectorized kernel expects contiguous row-major input + kernel_input = input_tensor.contiguous() + elif version == "cuda_rowmajor_128x4_vec": + if mxfp8_cuda is None: + raise RuntimeError("CUDA kernel not available") + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec + # Row-major 128x4 vectorized kernel expects contiguous row-major input + kernel_input = input_tensor.contiguous() + else: + raise ValueError(f"Unknown version: {version}") + + # Run kernel to get output shape + out_scales = kernel_fn( + kernel_input, + input_group_offsets, + ) + + # Benchmark the kernel + # Note: column-major tensors are not "contiguous" in PyTorch's row-major sense, + # but they are valid and have the expected strides for the CUDA kernel + time_us = 
benchmark_cuda_function_in_microseconds( + kernel_fn, + kernel_input, + input_group_offsets, + ) + + # Calculate memory bandwidth + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = out_scales.numel() * bytes_per_output_el + + mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6) + + return ExperimentResult( + time_us=time_us, + mem_bw_gbps=mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + # Group experiments by input shape + shapes_dict = {} + for exp in experiments: + shape_key = exp.config.input_shape + if shape_key not in shapes_dict: + shapes_dict[shape_key] = {} + shapes_dict[shape_key][exp.config.version] = exp.result + + headers = [ + "kernel_version", + "input_shape", + "time_us", + "mem_bw_gbps", + "fastest_version", + "speedup_vs_triton", + ] + + rows = [] + for shape, versions in shapes_dict.items(): + # Find fastest version for this shape + fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0] + + # Get triton baseline time for speedup calculation + triton_time_us = ( + versions.get("triton").time_us if "triton" in versions else None + ) + + # Add rows for each version + for version, result in versions.items(): + # Calculate speedup vs triton + speedup_str = "" + if version != "triton": + speedup = triton_time_us / result.time_us + speedup_str = f"{speedup:.2f}x" + + rows.append( + [ + version, + f"({shape[0]}, {shape[1]})", + f"{result.time_us:.2f}", + round(result.mem_bw_gbps, 3), + fastest_version, + speedup_str, + ] + ) + + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 9d2d7bce1c..9136371a4d 100644 --- a/setup.py +++ b/setup.py @@ -702,6 +702,7 @@ def get_extensions(): mxfp8_sources = [ os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"), os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"), + os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"), ] # Only add the extension if the source files exist AND we are building for sm100 diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index ecd4cefe6a..f75294bcd5 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -354,3 +354,65 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode): # Check quantized values torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0) assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match" + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("m", [256, 512, 1024, 5120]) +@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384]) +@pytest.mark.parametrize("n_groups", [1, 4, 8, 16]) +def test_cuda_mx_block_rearrange_2d_K_groups( + m: int, + total_k: int, + n_groups: int, +): + """ + Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference. + This kernel rearranges E8M0 scales to block-scaled swizzle format for cuBLAS Tmem. 
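+    Worked shape example (for intuition, from the assertions below): with m=256,
+    total_k=1024, n_groups=4 and block_size=32, the e8m0 scales have shape
+    (256, 32); rows are already a multiple of 128 and 4 padding columns are
+    reserved per group, so the blocked output is (256, 32 + 4 * 4) = (256, 48).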
+ """ + from torchao.prototype import mxfp8_cuda + + device = "cuda" + block_size = 32 + input_data = torch.randn(m, total_k, device=device) + + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + scale_group_offsets = input_group_offsets // block_size + + # Triton reference implementation + triton_out_scales = triton_mx_block_rearrange_2d_K_groups( + e8m0_scales, + scale_group_offsets, + ) + + # CUDA kernel implementation + cuda_out_scales = mxfp8_cuda.mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + e8m0_scales.view(torch.uint8), + scale_group_offsets, + ) + + # Check that outputs match + assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), ( + "CUDA and Triton blocked scales not equal" + ) + + # Verify output shape + expected_rows = ((m + 127) // 128) * 128 # Padded to multiple of 128 + expected_cols = ( + e8m0_scales.size(1) + n_groups * 4 + ) # Original cols + padding per group + assert cuda_out_scales.shape == ( + expected_rows, + expected_cols, + ), ( + f"Output shape mismatch: expected {(expected_rows, expected_cols)}, got {cuda_out_scales.shape}" + ) diff --git a/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu b/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu new file mode 100644 index 0000000000..7fb45fb573 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu @@ -0,0 +1,965 @@ +#include +#include +#include +#include + +#define BLOCK_ROWS 128 +#define BLOCK_COLS 4 +#define BLOCK_ROWS_LARGE 512 +#define BYTES_PER_THREAD 16 +#define SCALE_FACTOR_ROWS 128 + +__device__ __forceinline__ int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +__device__ void find_group_and_local_offset( + int col_block_pid, + const int32_t* __restrict__ input_group_end_offsets, + int num_groups, + int cols_per_block, + int* __restrict__ smem_cumsum, + int& group_id, + int& local_col_block +) { + if (threadIdx.x == 0) { + int cumsum = 0; + for (int g = 0; g < num_groups; g++) { + int input_group_start = (g > 0) ? input_group_end_offsets[g - 1] : 0; + int input_group_end = input_group_end_offsets[g]; + int group_size = input_group_end - input_group_start; + int num_col_blocks = ceil_div(group_size, cols_per_block); + cumsum += num_col_blocks; + smem_cumsum[g] = cumsum; + } + } + __syncthreads(); + + group_id = 0; + int cumsum_before = 0; + for (int g = 0; g < num_groups; g++) { + int cumsum_at_g = smem_cumsum[g]; + if (col_block_pid < cumsum_at_g) { + group_id = g; + local_col_block = col_block_pid - cumsum_before; + break; + } + cumsum_before = cumsum_at_g; + } +} + +__device__ __forceinline__ int compute_output_group_start_col( + int group_id, + const int32_t* input_group_end_offsets, + int num_groups, + int padding_size +) { + int start_idx = 0; + for (int i = 0; i < group_id; i++) { + int prev_offset = (i > 0) ? 
input_group_end_offsets[i - 1] : 0; + int curr_offset = input_group_end_offsets[i]; + int group_size = curr_offset - prev_offset; + int padded_size = ceil_div(group_size, padding_size) * padding_size; + start_idx += padded_size; + } + return start_idx; +} + +__device__ __forceinline__ int compute_swizzled_index(int row, int col) { + int r_div_32 = row / 32; + int r_mod_32 = row % 32; + return r_mod_32 * 16 + r_div_32 * 4 + col; +} + +// Row-major kernel: Input tensor has cols contiguous +__global__ void mx_block_rearrange_2d_K_groups_rowmajor_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + + __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS * BLOCK_COLS]; + __shared__ int smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + BLOCK_COLS, + smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * BLOCK_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + int input_row = row_block_pid * BLOCK_ROWS + tid; + int cols_remaining = input_group_end_col - curr_input_start_col; + int cols_to_load = min(BLOCK_COLS, cols_remaining); + + uint32_t row_data = 0; + if (input_row < scale_rows && curr_input_start_col < input_group_end_col) { + int input_offset = input_row * scales_stride_dim0 + curr_input_start_col; + const uint8_t* input_ptr = scales_ptr + input_offset; + + uintptr_t ptr_addr = reinterpret_cast(input_ptr); + if (cols_to_load >= 4 && ptr_addr % 4 == 0 && curr_input_start_col + 4 <= input_group_end_col) { + row_data = __ldg(reinterpret_cast(input_ptr)); + } else { + uint8_t* row_bytes = reinterpret_cast(&row_data); + for (int i = 0; i < cols_to_load && (curr_input_start_col + i) < input_group_end_col; i++) { + row_bytes[i] = __ldg(input_ptr + i); + } + } + } + + uint8_t* row_bytes = reinterpret_cast(&row_data); + #pragma unroll + for (int col = 0; col < BLOCK_COLS; col++) { + int swizzled_idx = compute_swizzled_index(tid, col); + smem_block[swizzled_idx] = row_bytes[col]; + } + + __syncthreads(); + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * output_stride_per_block; + + int offset_in_group = row_block_pid * stride_per_row_of_blocks_in_group + + local_col_block * output_stride_per_block; + int final_offset = out_group_base_offset + offset_in_group; + + uint8_t* output_ptr = output_scales_ptr + final_offset + tid * BLOCK_COLS; + uintptr_t out_ptr_addr = reinterpret_cast(output_ptr); + + if (out_ptr_addr % 4 == 0 && cols_to_load >= 4) { + *reinterpret_cast(output_ptr) = + *reinterpret_cast(&smem_block[tid * 
BLOCK_COLS]); + } else { + const uint8_t* smem_ptr = &smem_block[tid * BLOCK_COLS]; + #pragma unroll + for (int i = 0; i < cols_to_load; i++) { + output_ptr[i] = smem_ptr[i]; + } + } +} + +// Column-major kernel: Input tensor has rows contiguous +__global__ void mx_block_rearrange_2d_K_groups_colmajor_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim0, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + + __shared__ __align__(16) uint8_t smem_swizzled[BLOCK_ROWS * BLOCK_COLS]; + __shared__ int smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + BLOCK_COLS, + smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * BLOCK_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + int cols_remaining = input_group_end_col - curr_input_start_col; + int row_in_block = tid; + int global_row = row_block_pid * BLOCK_ROWS + row_in_block; + + uint32_t packed_scales = 0; + uint8_t* local_vals = reinterpret_cast(&packed_scales); + + if (global_row < scale_rows) { + #pragma unroll + for (int c = 0; c < BLOCK_COLS; ++c) { + if (c < cols_remaining) { + int global_col = curr_input_start_col + c; + size_t offset = static_cast(global_col) * scales_stride_dim1 + global_row; + local_vals[c] = __ldg(scales_ptr + offset); + } + } + } + + int r_div_32 = row_in_block >> 5; + int r_mod_32 = row_in_block & 31; + int smem_offset = (r_mod_32 << 4) + (r_div_32 << 2); + + *reinterpret_cast(&smem_swizzled[smem_offset]) = packed_scales; + + __syncthreads(); + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * output_stride_per_block; + + int offset_in_group = row_block_pid * stride_per_row_of_blocks_in_group + + local_col_block * output_stride_per_block; + int final_offset = out_group_base_offset + offset_in_group; + + uint8_t* output_ptr = output_scales_ptr + final_offset + tid * BLOCK_COLS; + + *reinterpret_cast(output_ptr) = + *reinterpret_cast(&smem_swizzled[tid * BLOCK_COLS]); +} + +// Column-major vectorized kernel: 4 warps, each processing one column with uint32_t loads +__global__ void mx_block_rearrange_2d_K_groups_colmajor_vectorized_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane_id = tid & 
31; + + __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS * BLOCK_COLS]; + __shared__ int smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + BLOCK_COLS, + smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * BLOCK_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + int cols_remaining = input_group_end_col - curr_input_start_col; + int global_row_base = row_block_pid * BLOCK_ROWS; + + uint32_t loaded_data = 0; + int col_idx = warp_id; + + if (col_idx < cols_remaining) { + int global_col = curr_input_start_col + col_idx; + int row_start = global_row_base + lane_id * 4; + + const uint8_t* col_ptr = scales_ptr + + static_cast(global_col) * scales_stride_dim1; + + if (row_start + 3 < scale_rows) { + loaded_data = __ldg(reinterpret_cast(col_ptr + row_start)); + } else if (row_start < scale_rows) { + uint8_t* bytes = reinterpret_cast(&loaded_data); + #pragma unroll + for (int i = 0; i < 4; i++) { + if (row_start + i < scale_rows) { + bytes[i] = __ldg(col_ptr + row_start + i); + } + } + } + } + + uint8_t* bytes = reinterpret_cast(&loaded_data); + + #pragma unroll + for (int i = 0; i < 4; i++) { + int row_in_block = lane_id * 4 + i; + int r_div_32 = row_in_block >> 5; + int r_mod_32 = row_in_block & 31; + int swizzle_idx = (r_mod_32 << 4) + (r_div_32 << 2) + col_idx; + smem_block[swizzle_idx] = bytes[i]; + } + + __syncthreads(); + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * output_stride_per_block; + + int offset_in_group = row_block_pid * stride_per_row_of_blocks_in_group + + local_col_block * output_stride_per_block; + int final_offset = out_group_base_offset + offset_in_group; + + uint8_t* output_ptr = output_scales_ptr + final_offset + tid * BLOCK_COLS; + + *reinterpret_cast(output_ptr) = + *reinterpret_cast(&smem_block[tid * BLOCK_COLS]); +} + +// Column-major 16B vectorized kernel: 512-row blocks, uint4 loads (16 bytes per thread) +__global__ void mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane_id = tid & 31; + + __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS_LARGE * BLOCK_COLS]; + __shared__ int smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + BLOCK_COLS, + smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? 
input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * BLOCK_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + __syncthreads(); + + int cols_remaining = input_group_end_col - curr_input_start_col; + int global_row_base = row_block_pid * BLOCK_ROWS_LARGE; + + uint4 loaded_data = make_uint4(0, 0, 0, 0); + int col_idx = warp_id; + + if (col_idx < cols_remaining) { + int global_col = curr_input_start_col + col_idx; + int row_start = global_row_base + lane_id * BYTES_PER_THREAD; + + const uint8_t* col_ptr = scales_ptr + + static_cast(global_col) * scales_stride_dim1; + + if (row_start + BYTES_PER_THREAD - 1 < scale_rows) { + loaded_data = __ldg(reinterpret_cast(col_ptr + row_start)); + } else if (row_start < scale_rows) { + uint8_t* bytes = reinterpret_cast(&loaded_data); + #pragma unroll + for (int i = 0; i < BYTES_PER_THREAD; i++) { + if (row_start + i < scale_rows) { + bytes[i] = __ldg(col_ptr + row_start + i); + } + } + } + } + + uint8_t* bytes = reinterpret_cast(&loaded_data); + + #pragma unroll + for (int i = 0; i < BYTES_PER_THREAD; i++) { + int row_in_block = lane_id * BYTES_PER_THREAD + i; + int tile_idx = row_in_block / SCALE_FACTOR_ROWS; + int row_in_tile = row_in_block % SCALE_FACTOR_ROWS; + int tile_base_offset = tile_idx * SCALE_FACTOR_ROWS * BLOCK_COLS; + + int r_div_32 = row_in_tile >> 5; + int r_mod_32 = row_in_tile & 31; + int swizzle_idx = (r_mod_32 << 4) + (r_div_32 << 2) + col_idx; + + smem_block[tile_base_offset + swizzle_idx] = bytes[i]; + } + + __syncthreads(); + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + + constexpr int TILE_SIZE = SCALE_FACTOR_ROWS * BLOCK_COLS; + int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * TILE_SIZE; + + #pragma unroll + for (int r = 0; r < 4; r++) { + int row_idx = tid + r * 128; + int tile_idx = row_idx >> 7; + int row_in_tile = row_idx & 127; + int actual_row_block = row_block_pid * 4 + tile_idx; + + int offset_in_group = actual_row_block * stride_per_row_of_blocks_in_group + + local_col_block * TILE_SIZE; + int final_offset = out_group_base_offset + offset_in_group; + + uint8_t* output_ptr = output_scales_ptr + final_offset + row_in_tile * BLOCK_COLS; + + *reinterpret_cast(output_ptr) = + *reinterpret_cast(&smem_block[row_idx * BLOCK_COLS]); + } +} + +// Row-major vectorized kernel: 512-row blocks, uint32_t loads per row +__global__ void mx_block_rearrange_2d_K_groups_rowmajor_vectorized_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + + __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS_LARGE * BLOCK_COLS]; + __shared__ int smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + BLOCK_COLS, + 
smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * BLOCK_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + __syncthreads(); + + int cols_remaining = input_group_end_col - curr_input_start_col; + int cols_to_load = min(BLOCK_COLS, cols_remaining); + int global_row_base = row_block_pid * BLOCK_ROWS_LARGE; + + #pragma unroll + for (int r = 0; r < 4; r++) { + int row_idx = tid + r * 128; + int global_row = global_row_base + row_idx; + + uint32_t row_data = 0; + + if (global_row < scale_rows) { + const uint8_t* row_ptr = scales_ptr + + static_cast(global_row) * scales_stride_dim0 + curr_input_start_col; + + uintptr_t ptr_addr = reinterpret_cast(row_ptr); + bool aligned = (ptr_addr % 4 == 0); + + if (cols_to_load == 4 && aligned) { + row_data = __ldg(reinterpret_cast(row_ptr)); + } else { + uint8_t* bytes = reinterpret_cast(&row_data); + #pragma unroll + for (int c = 0; c < BLOCK_COLS; c++) { + if (c < cols_to_load) { + bytes[c] = __ldg(row_ptr + c); + } + } + } + } + + uint8_t* bytes = reinterpret_cast(&row_data); + + int tile_idx = row_idx >> 7; + int row_in_tile = row_idx & 127; + int tile_base_offset = tile_idx * SCALE_FACTOR_ROWS * BLOCK_COLS; + + int r_div_32 = row_in_tile >> 5; + int r_mod_32 = row_in_tile & 31; + int swizzle_base = (r_mod_32 << 4) + (r_div_32 << 2); + + #pragma unroll + for (int c = 0; c < BLOCK_COLS; c++) { + smem_block[tile_base_offset + swizzle_base + c] = bytes[c]; + } + } + + __syncthreads(); + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + + constexpr int TILE_SIZE = SCALE_FACTOR_ROWS * BLOCK_COLS; + int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * TILE_SIZE; + + #pragma unroll + for (int r = 0; r < 4; r++) { + int row_idx = tid + r * 128; + int tile_idx = row_idx >> 7; + int row_in_tile = row_idx & 127; + int actual_row_block = row_block_pid * 4 + tile_idx; + + int offset_in_group = actual_row_block * stride_per_row_of_blocks_in_group + + local_col_block * TILE_SIZE; + int final_offset = out_group_base_offset + offset_in_group; + + uint8_t* output_ptr = output_scales_ptr + final_offset + row_in_tile * BLOCK_COLS; + + *reinterpret_cast(output_ptr) = + *reinterpret_cast(&smem_block[row_idx * BLOCK_COLS]); + } +} + +#define MAX_COLS 64 // Maximum columns per threadblock (4 threads wide, 16 bytes per thread) + +__global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + const int col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + + // Single shared memory buffer for swizzled output (8KB instead of 16KB) + // Data is written directly in swizzled format during load + __shared__ __align__(16) uint8_t smem[BLOCK_ROWS * MAX_COLS]; // 128 * 64 = 8KB + __shared__ int 
smem_cumsum[32]; + __shared__ int output_group_start_col; + + int group_id, local_col_block; + find_group_and_local_offset( + col_block_pid, + input_group_end_offsets, + num_groups, + MAX_COLS, + smem_cumsum, + group_id, + local_col_block + ); + + int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + int input_group_end_col = input_group_end_offsets[group_id]; + int curr_input_start_col = input_group_start_col + local_col_block * MAX_COLS; + + if (curr_input_start_col >= input_group_end_col) { + return; + } + + if (tid == 0) { + output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + } + + __syncthreads(); + + int cols_remaining = input_group_end_col - curr_input_start_col; + int cols_to_load = min(MAX_COLS, cols_remaining); + int global_row_base = row_block_pid * BLOCK_ROWS; + + // Thread layout: 512 threads = 128 rows × 4 threads per row + int row_idx = tid / BLOCK_COLS; // 0-127 + int col_idx = tid % BLOCK_COLS; // 0-3 + int global_row = global_row_base + row_idx; + + // Compute swizzle base offset for this thread's row + int r_div_32 = row_idx >> 5; // row / 32 + int r_mod_32 = row_idx & 31; // row % 32 + int swizzle_base = (r_mod_32 << 4) + (r_div_32 << 2); // r_mod_32 * 16 + r_div_32 * 4 + int thread_col_start = col_idx * 16; // 0, 16, 32, or 48 + + // ============================================================ + // PHASE 1: Load from GMEM directly to SMEM in swizzled format + // Each thread loads 16 bytes from GMEM using vectorized load, + // then scatter-writes to swizzled positions in SMEM. + uint4 data = make_uint4(0, 0, 0, 0); + + if (global_row < scale_rows && thread_col_start < cols_to_load) + { + const uint8_t* row_ptr = scales_ptr + + static_cast(global_row) * scales_stride_dim0 + curr_input_start_col; + + uintptr_t gmem_addr = reinterpret_cast(row_ptr + thread_col_start); + bool aligned = (gmem_addr % 16 == 0); + + // Load 16 bytes from GMEM (vectorized if aligned and full) + if (thread_col_start + 16 <= cols_to_load && aligned) + { + data = __ldg(reinterpret_cast(row_ptr + thread_col_start)); + } + else + { + // Partial/unaligned load + uint8_t* bytes = reinterpret_cast(&data); + int bytes_to_load = min(16, cols_to_load - thread_col_start); + #pragma unroll + for (int i = 0; i < 16; i++) { + if (i < bytes_to_load) { + bytes[i] = __ldg(row_ptr + thread_col_start + i); + } + } + } + } + // else: data remains zero (padding rows or columns beyond cols_to_load) + + // Scatter-write to swizzled SMEM positions using vectorized uint32 writes + // Each group of 4 columns within a tile are contiguous in SMEM + // So we can write 4 bytes at a time (4 tiles × 1 uint32 write = 4 writes total) + uint32_t* data32 = reinterpret_cast(&data); + int first_tile_idx = thread_col_start >> 2; // thread_col_start / 4 + + #pragma unroll + for (int t = 0; t < 4; t++) { + int tile_idx = first_tile_idx + t; + int tile_base = tile_idx * SCALE_FACTOR_ROWS * BLOCK_COLS; // tile_idx * 128 * 4 + int swizzled_idx = tile_base + swizzle_base; + *reinterpret_cast(&smem[swizzled_idx]) = data32[t]; + } + + __syncthreads(); + + // PHASE 2: Store from SMEM to GMEM, data already in swizzled format, can do a direct copy + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_4col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + + constexpr int TILE_SIZE = SCALE_FACTOR_ROWS * BLOCK_COLS; // 128 * 4 = 512 + int 
stride_per_row_of_4col_blocks = num_4col_blocks_in_group * TILE_SIZE; + + // tiles_before_this_block: 4-col tiles that came before this 64-col block + int tiles_before_this_block = local_col_block * (MAX_COLS / BLOCK_COLS); + + // Base output pointer for this threadblock + uint8_t* out_base = output_scales_ptr + out_group_base_offset + + row_block_pid * stride_per_row_of_4col_blocks + + tiles_before_this_block * TILE_SIZE; + + // Number of 4-column tiles in this threadblock (max 16 for 64 columns) + int num_tiles_this_block = ceil_div(cols_to_load, BLOCK_COLS); + int bytes_to_copy = num_tiles_this_block * TILE_SIZE; + + // Each thread copies 16 bytes using uint4 + // Thread tid copies bytes [tid*16, tid*16+15] + int byte_offset = tid * 16; + if (byte_offset < bytes_to_copy) + { + *reinterpret_cast(out_base + byte_offset) = + *reinterpret_cast(&smem[byte_offset]); + } +} + +namespace mxfp8 { + +void launch_mx_block_rearrange_2d_K_groups_rowmajor( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS; + int total_col_blocks = (scale_cols + BLOCK_COLS - 1) / BLOCK_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(128); + + mx_block_rearrange_2d_K_groups_rowmajor_kernel<<>>( + scales_ptr, + scales_stride_dim0, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +void launch_mx_block_rearrange_2d_K_groups_colmajor( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS; + int total_col_blocks = (scale_cols + BLOCK_COLS - 1) / BLOCK_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(128); + + mx_block_rearrange_2d_K_groups_colmajor_kernel<<>>( + scales_ptr, + scales_stride_dim0, + scales_stride_dim1, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +void launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized( + const uint8_t* scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS; + int total_col_blocks = (scale_cols + BLOCK_COLS - 1) / BLOCK_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(128); + + mx_block_rearrange_2d_K_groups_colmajor_vectorized_kernel<<>>( + scales_ptr, + scales_stride_dim1, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + 
output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +void launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B( + const uint8_t* scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS_LARGE - 1) / BLOCK_ROWS_LARGE; + int output_stride_per_block = BLOCK_ROWS_LARGE * BLOCK_COLS; + int total_col_blocks = (scale_cols + BLOCK_COLS - 1) / BLOCK_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(128); + + mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B_kernel<<>>( + scales_ptr, + scales_stride_dim1, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_vectorized( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS_LARGE - 1) / BLOCK_ROWS_LARGE; + int output_stride_per_block = BLOCK_ROWS_LARGE * BLOCK_COLS; + int total_col_blocks = (scale_cols + BLOCK_COLS - 1) / BLOCK_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(128); + + mx_block_rearrange_2d_K_groups_rowmajor_vectorized_kernel<<>>( + scales_ptr, + scales_stride_dim0, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS; + // Each col_block handles MAX_COLS (128) columns + int total_col_blocks = (scale_cols + MAX_COLS - 1) / MAX_COLS + num_groups; + + dim3 grid(total_col_blocks, num_row_blocks); + dim3 block(512); // 512 threads for 128x128 block + + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_kernel<<>>( + scales_ptr, + scales_stride_dim0, + scale_rows, + scale_cols, + padded_rows, + input_group_end_offsets, + output_scales_ptr, + output_stride_per_block, + num_groups + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error (128x4 vec): %s\n", cudaGetErrorString(err)); + } +} + +} // namespace mxfp8 diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp index d445fcad4d..80f360fd47 100644 --- a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp @@ -25,10 +25,76 @@ void mxfp8_quantize_3d_cuda(const torch::Tensor &input, const std::string &fp8_format, const std::string &scaling_mode); 
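+
+// Layout shared by the launchers declared below (as implemented in
+// mx_block_rearrange_2d_K_groups.cu): each K group's scale columns are padded
+// up to a multiple of 4, rows are padded up to a multiple of the kernel's row
+// block (128, or 512 for the 16B column-major variant), and every 128x4 tile
+// (row in [0, 128), col in [0, 4)) is stored in the swizzled order
+//   swizzled_idx = (row % 32) * 16 + (row / 32) * 4 + col
+// e.g. (row=33, col=2) maps to 1 * 16 + 1 * 4 + 2 = 22 within its tile.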
+void launch_mx_block_rearrange_2d_K_groups_rowmajor( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + +void launch_mx_block_rearrange_2d_K_groups_colmajor( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + +void launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized( + const uint8_t* scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + +void launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B( + const uint8_t* scales_ptr, + int scales_stride_dim1, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_vectorized( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + cudaStream_t stream); + // Helper for tensor validation void check_cuda_tensor(const torch::Tensor &t, const char *name) { TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); - TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); } // Helper to validate FP8 format @@ -177,6 +243,369 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n, return std::make_tuple(output_colwise, scales_colwise); } +// Python wrapper for mx_block_rearrange_2d_K_groups (row-major input) +torch::Tensor mx_block_rearrange_2d_K_groups_rowmajor( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.is_contiguous(), "scales_tensor must be contiguous (row-major)"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed + const int BLOCK_ROWS = 128; + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + const int padded_rows = num_row_blocks * BLOCK_ROWS; + + // Padding per group is variable/data 
dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch row-major kernel + launch_mx_block_rearrange_2d_K_groups_rowmajor( + scales_ptr, + scales_tensor.stride(0), + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +// Python wrapper for mx_block_rearrange_2d_K_groups (column-major input) +torch::Tensor mx_block_rearrange_2d_K_groups_colmajor( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed + const int BLOCK_ROWS = 128; + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + const int padded_rows = num_row_blocks * BLOCK_ROWS; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch column-major kernel + launch_mx_block_rearrange_2d_K_groups_colmajor( + scales_ptr, + scales_tensor.stride(0), + scales_tensor.stride(1), + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +// Python wrapper for mx_block_rearrange_2d_K_groups (column-major input, vectorized loads) +torch::Tensor mx_block_rearrange_2d_K_groups_colmajor_vectorized( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + 
TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed + const int BLOCK_ROWS = 128; + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + const int padded_rows = num_row_blocks * BLOCK_ROWS; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch column-major vectorized kernel + launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized( + scales_ptr, + scales_tensor.stride(1), // Only need stride_dim1 for vectorized kernel + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +// Python wrapper for mx_block_rearrange_2d_K_groups (column-major input, 16-byte vectorized loads) +torch::Tensor mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed - uses larger 512-row blocks + const int BLOCK_ROWS_LARGE = 512; + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS_LARGE - 1) / BLOCK_ROWS_LARGE; + const int padded_rows = num_row_blocks * BLOCK_ROWS_LARGE; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch column-major vectorized 16B kernel + launch_mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B( + scales_ptr, + scales_tensor.stride(1), // Only need stride_dim1 for vectorized kernel + rows, + cols, + padded_rows, + offsets_ptr, + 
output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +// Python wrapper for mx_block_rearrange_2d_K_groups (row-major input, vectorized loads) +torch::Tensor mx_block_rearrange_2d_K_groups_rowmajor_vectorized( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.is_contiguous(), "scales_tensor must be contiguous (row-major)"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed - uses larger 512-row blocks + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + 128 - 1) / 128; + const int padded_rows = num_row_blocks * 128; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch row-major vectorized kernel + launch_mx_block_rearrange_2d_K_groups_rowmajor_vectorized( + scales_ptr, + scales_tensor.stride(0), + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +// Python wrapper for mx_block_rearrange_2d_K_groups (row-major input, 128x4 vectorized) +torch::Tensor mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + torch::Tensor scales_tensor, + torch::Tensor input_group_end_offsets) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.is_contiguous(), "scales_tensor must be contiguous (row-major)"); + TORCH_CHECK(scales_tensor.scalar_type() == torch::kUInt8 || + scales_tensor.scalar_type() == torch::kFloat8_e8m0fnu, + "scales_tensor must be uint8 or e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == torch::kInt32, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed - uses 128-row blocks, processing up to 64 columns at a time + const int BLOCK_ROWS = 128; + 
const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + const int padded_rows = num_row_blocks * BLOCK_ROWS; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = torch::zeros({padded_rows, padded_cols}, + torch::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers + const uint8_t* scales_ptr = scales_tensor.data_ptr(); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = output.data_ptr(); + + // Launch row-major 128x4 vectorized kernel + launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + scales_ptr, + scales_tensor.stride(0), + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + at::cuda::getCurrentCUDAStream()); + + return output; +} + } // namespace mxfp8 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -192,4 +621,40 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("input"), py::arg("scale_dim_n") = 32, py::arg("fp8_format") = "e4m3", py::arg("scaling_mode") = "floor"); + + m.def("mx_block_rearrange_2d_K_groups_rowmajor", + &mxfp8::mx_block_rearrange_2d_K_groups_rowmajor, + "Rearrange E8M0 scales to block-scaled swizzle format (row-major input)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); + + m.def("mx_block_rearrange_2d_K_groups_colmajor", + &mxfp8::mx_block_rearrange_2d_K_groups_colmajor, + "Rearrange E8M0 scales to block-scaled swizzle format (column-major input)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); + + m.def("mx_block_rearrange_2d_K_groups_colmajor_vectorized", + &mxfp8::mx_block_rearrange_2d_K_groups_colmajor_vectorized, + "Rearrange E8M0 scales to block-scaled swizzle format (column-major input, vectorized loads)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); + + m.def("mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B", + &mxfp8::mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B, + "Rearrange E8M0 scales to block-scaled swizzle format (column-major input, 16-byte vectorized loads, 512-row blocks)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); + + m.def("mx_block_rearrange_2d_K_groups_rowmajor_vectorized", + &mxfp8::mx_block_rearrange_2d_K_groups_rowmajor_vectorized, + "Rearrange E8M0 scales to block-scaled swizzle format (row-major input, vectorized loads, 512-row blocks)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); + + m.def("mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec", + &mxfp8::mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec, + "Rearrange E8M0 scales to block-scaled swizzle format (row-major input, 128x4 vectorized, 128-row blocks)", + py::arg("scales_tensor"), + py::arg("input_group_end_offsets")); } diff --git a/torchao/csrc/cuda/mx_kernels/test_mx_block_rearrange_standalone.py b/torchao/csrc/cuda/mx_kernels/test_mx_block_rearrange_standalone.py new file mode 100644 index 0000000000..dc94b27c41 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/test_mx_block_rearrange_standalone.py @@ -0,0 +1,308 @@ +""" +Standalone test for mx_block_rearrange_2d_K_groups CUDA kernel. +Tests both row-major and column-major kernel variants. +Uses torch.utils.cpp_extension.load for quick compilation and iteration. 
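+Note: the extension is compiled with -gencode=arch=compute_100,code=sm_100, so a
+GPU with CUDA compute capability 10.0 is required to run this script.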
+ +Usage: + python test_mx_block_rearrange_standalone.py +""" + +import os +import sys + +import torch +from torch.utils.cpp_extension import load + +# Get the directory where this script is located +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Load the CUDA extension +print("Compiling CUDA kernel...") +mx_block_rearrange = load( + name="mx_block_rearrange_2d_K_groups", + sources=[ + os.path.join(SCRIPT_DIR, "mxfp8_extension.cpp"), + os.path.join(SCRIPT_DIR, "mxfp8_cuda.cu"), + os.path.join(SCRIPT_DIR, "mx_block_rearrange_2d_K_groups.cu"), + ], + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-std=c++17", + "-gencode=arch=compute_100,code=sm_100", + "-Xptxas=-v", # Show register usage per kernel + ], + extra_cflags=["-O3", "-std=c++17"], + verbose=True, +) + +print("Compilation successful!") + + +def benchmark_kernel(kernel_fn, *args, warmup=10, iterations=100): + """Benchmark a kernel function and return average time in microseconds.""" + # Warmup + for _ in range(warmup): + kernel_fn(*args) + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(iterations): + kernel_fn(*args) + end_event.record() + + torch.cuda.synchronize() + elapsed_ms = start_event.elapsed_time(end_event) + return (elapsed_ms / iterations) * 1000 # Convert to microseconds + + +def test_kernel(): + print("\n" + "=" * 80) + print( + "Testing mx_block_rearrange_2d_K_groups kernels (row-major, column-major, vectorized)" + ) + print("=" * 80) + + # Try importing the Triton reference implementation + try: + ao_root = os.path.abspath(os.path.join(SCRIPT_DIR, "..", "..", "..", "..")) + sys.path.insert(0, ao_root) + + from torchao.prototype.moe_training.kernels.mxfp8.quant import ( + triton_mx_block_rearrange_2d_K_groups, + ) + from torchao.prototype.moe_training.utils import generate_jagged_offs + from torchao.prototype.mx_formats.mx_tensor import to_mx + + has_triton = True + print("Triton reference implementation available") + except ImportError as e: + print(f"WARNING: Triton reference not available: {e}") + has_triton = False + + # Test parameters - use larger size for meaningful benchmarks + device = "cuda" + m, total_k = 7168, 131072 + n_groups = 8 + block_size = 32 + + print("\nTest configuration:") + print(f" Matrix size: {m} x {total_k}") + print(f" Number of groups: {n_groups}") + + # Generate test data + print("\nGenerating test data...") + torch.manual_seed(42) + input_data = torch.randn(m, total_k, device=device) + + if has_triton: + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + scale_group_offsets = input_group_offsets // block_size + + print(f" Scales shape: {e8m0_scales.shape}") + else: + return False + + rows, cols = e8m0_scales.shape + + # Prepare row-major input (default contiguous) + e8m0_scales_row_major = e8m0_scales.contiguous() + assert e8m0_scales_row_major.is_contiguous(), "Row-major input should be contiguous" + + # Prepare column-major input (same shape, different memory layout) + e8m0_scales_col_major = e8m0_scales.T.contiguous().T + assert e8m0_scales_col_major.shape == e8m0_scales.shape, "Shape should be preserved" + assert e8m0_scales_col_major.stride() == ( + 1, + rows, + ), ( + f"Expected column-major strides (1, {rows}), got {e8m0_scales_col_major.stride()}" + ) + + # 
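+    # Illustrative helper (not part of the kernels under test): the C++ wrappers
+    # above allocate swizzled outputs of shape (ceil(rows/128)*128, cols + n_groups*4),
+    # where the extra n_groups*4 columns are an upper bound on per-group padding.
+    def expected_padded_shape(n_rows: int, n_cols: int, groups: int) -> tuple[int, int]:
+        padded_rows = ((n_rows + 127) // 128) * 128
+        padded_cols = n_cols + groups * 4  # upper bound; unused columns stay zero
+        return padded_rows, padded_cols
+
+    exp_rows, exp_cols = expected_padded_shape(rows, cols, n_groups)
+    print(f"  Expected swizzled output allocation: {exp_rows} x {exp_cols}")
+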
------------------------------------------------------------------------- + # Test Row-Major CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA row-major kernel...") + print( + f" Input shape: {e8m0_scales_row_major.shape}, strides: {e8m0_scales_row_major.stride()}" + ) + cuda_rowmajor_out = mx_block_rearrange.mx_block_rearrange_2d_K_groups_rowmajor( + e8m0_scales_row_major.view(torch.uint8), + scale_group_offsets, + ) + print("CUDA row-major kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Column-Major CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA column-major kernel...") + print( + f" Input shape: {e8m0_scales_col_major.shape}, strides: {e8m0_scales_col_major.stride()}" + ) + cuda_colmajor_out = mx_block_rearrange.mx_block_rearrange_2d_K_groups_colmajor( + e8m0_scales_col_major.view(torch.uint8), + scale_group_offsets, + ) + print("CUDA column-major kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Column-Major Vectorized CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA column-major vectorized kernel...") + print( + f" Input shape: {e8m0_scales_col_major.shape}, strides: {e8m0_scales_col_major.stride()}" + ) + cuda_colmajor_vec_out = ( + mx_block_rearrange.mx_block_rearrange_2d_K_groups_colmajor_vectorized( + e8m0_scales_col_major.view(torch.uint8), + scale_group_offsets, + ) + ) + print("CUDA column-major vectorized kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Column-Major Vectorized 16B CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA column-major vectorized 16B kernel...") + print( + f" Input shape: {e8m0_scales_col_major.shape}, strides: {e8m0_scales_col_major.stride()}" + ) + cuda_colmajor_vec_16B_out = ( + mx_block_rearrange.mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B( + e8m0_scales_col_major.view(torch.uint8), + scale_group_offsets, + ) + ) + print("CUDA column-major vectorized 16B kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Row-Major Vectorized CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA row-major vectorized kernel...") + print( + f" Input shape: {e8m0_scales_row_major.shape}, strides: {e8m0_scales_row_major.stride()}" + ) + cuda_rowmajor_vec_out = ( + mx_block_rearrange.mx_block_rearrange_2d_K_groups_rowmajor_vectorized( + e8m0_scales_row_major.view(torch.uint8), + scale_group_offsets, + ) + ) + print("CUDA row-major vectorized kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Row-Major 128x4 Vectorized CUDA Kernel + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running CUDA row-major 128x4 vectorized kernel...") + print( + f" Input shape: {e8m0_scales_row_major.shape}, strides: {e8m0_scales_row_major.stride()}" + ) + cuda_rowmajor_128x4_vec_out = ( + 
mx_block_rearrange.mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec( + e8m0_scales_row_major.view(torch.uint8), + scale_group_offsets, + ) + ) + print("CUDA row-major 128x4 vectorized kernel completed successfully") + + # ------------------------------------------------------------------------- + # Test Triton Reference + # ------------------------------------------------------------------------- + print("\n" + "-" * 80) + print("Running Triton reference kernel...") + triton_out = triton_mx_block_rearrange_2d_K_groups( + e8m0_scales, + scale_group_offsets, + ) + print("Triton kernel completed successfully") + + # ------------------------------------------------------------------------- + # Verify Correctness + # ------------------------------------------------------------------------- + print("\nVerifying correctness...") + cuda_rowmajor_out_e8m0 = cuda_rowmajor_out.view(torch.float8_e8m0fnu) + cuda_colmajor_out_e8m0 = cuda_colmajor_out.view(torch.float8_e8m0fnu) + cuda_colmajor_vec_out_e8m0 = cuda_colmajor_vec_out.view(torch.float8_e8m0fnu) + + all_correct = True + + if not torch.equal(triton_out, cuda_rowmajor_out_e8m0): + print("FAILED: CUDA row-major and Triton outputs differ!") + all_correct = False + else: + print("PASSED: CUDA row-major matches Triton") + + if not torch.equal(triton_out, cuda_colmajor_out_e8m0): + print("FAILED: CUDA column-major and Triton outputs differ!") + all_correct = False + else: + print("PASSED: CUDA column-major matches Triton") + + if not torch.equal(triton_out, cuda_colmajor_vec_out_e8m0): + print("FAILED: CUDA column-major vectorized and Triton outputs differ!") + all_correct = False + else: + print("PASSED: CUDA column-major vectorized matches Triton") + + cuda_colmajor_vec_16B_out_e8m0 = cuda_colmajor_vec_16B_out.view( + torch.float8_e8m0fnu + ) + if not torch.equal(triton_out, cuda_colmajor_vec_16B_out_e8m0): + print("FAILED: CUDA column-major vectorized 16B and Triton outputs differ!") + all_correct = False + else: + print("PASSED: CUDA column-major vectorized 16B matches Triton") + + cuda_rowmajor_vec_out_e8m0 = cuda_rowmajor_vec_out.view(torch.float8_e8m0fnu) + if not torch.equal(triton_out, cuda_rowmajor_vec_out_e8m0): + print("FAILED: CUDA row-major vectorized and Triton outputs differ!") + all_correct = False + else: + print("PASSED: CUDA row-major vectorized matches Triton") + + cuda_rowmajor_128x4_vec_out_e8m0 = cuda_rowmajor_128x4_vec_out.view( + torch.float8_e8m0fnu + ) + if not torch.equal(triton_out, cuda_rowmajor_128x4_vec_out_e8m0): + print("FAILED: CUDA row-major 128x4 vectorized and Triton outputs differ!") + # Print debug info for differences + diff_mask = triton_out != cuda_rowmajor_128x4_vec_out_e8m0 + num_diffs = diff_mask.sum().item() + print(f" Number of differences: {num_diffs} / {triton_out.numel()}") + all_correct = False + else: + print("PASSED: CUDA row-major 128x4 vectorized matches Triton") + + if not all_correct: + return False + + print("\nAll outputs are IDENTICAL!") + return True + + +if __name__ == "__main__": + success = test_kernel() + + print("\n" + "=" * 80) + if success: + print("ALL TESTS PASSED!") + sys.exit(0) + else: + print("TESTS FAILED") + sys.exit(1) diff --git a/torchao/prototype/moe_training/kernels/mxfp8/quant.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py index 24915d6359..1b34a6647c 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8/quant.py +++ b/torchao/prototype/moe_training/kernels/mxfp8/quant.py @@ -499,7 +499,6 @@ def triton_mx_block_rearrange_2d_K_groups( Args: 
scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales - output_group_start_offsets: tensor of int32 values representing pre-computed group start indexes after blocked format padding Returns: - Rearranged tensor in block-scaled swizzle format """ @@ -522,8 +521,7 @@ def triton_mx_block_rearrange_2d_K_groups( BLOCK_ROWS, BLOCK_COLS = 128, 4 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS - # We parallelize per group and per row block. - # Cols per group is variable, so we just loop through col blocks for each group. + # Naive grid - only parallelize by group and row grid = lambda META: ( num_groups, num_row_blocks, @@ -709,9 +707,10 @@ def mxfp8_quantize_cuda_3d( torch.Tensor: scales tensor """ assert x.ndim == 3, "Input tensor must be 3D" - assert x.dtype in (torch.float32, torch.bfloat16), ( - "Input tensor must be float32 or bfloat16" - ) + assert x.dtype in ( + torch.float32, + torch.bfloat16, + ), "Input tensor must be float32 or bfloat16" q_data, scales = mxfp8_cuda.quantize_3d( x, scale_dim_n=block_size, scaling_mode=scaling_mode ) @@ -724,9 +723,10 @@ def _fake_mxfp8_quantize_cuda_3d( scaling_mode: str = "floor", ) -> Tuple[torch.Tensor, torch.Tensor]: assert x.ndim == 3, "Input tensor must be 3D" - assert x.dtype in (torch.float32, torch.bfloat16), ( - "Input tensor must be float32 or bfloat16" - ) + assert x.dtype in ( + torch.float32, + torch.bfloat16, + ), "Input tensor must be float32 or bfloat16" E, N, K = x.shape # Quantized tensor is in column major layouts q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided( diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 3a4ad43b4f..4cb6525c16 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -26,13 +26,13 @@ _is_column_major, ) from torchao.prototype.mx_formats.config import ( + KernelPreference, MXFP8Dim1CastKernelChoice, ScaleCalculationMode, ) from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0 from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper -from torchao.quantization.quantize_.common import KernelPreference logger: logging.Logger = logging.getLogger(__name__) @@ -412,7 +412,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=grad_out.dtype, - kernel_preference=KernelPreference.AUTO, # Not used + kernel_preference=KernelPreference.AUTO, cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -428,7 +428,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=A.dtype, - kernel_preference=KernelPreference.AUTO, # Not used + kernel_preference=KernelPreference.AUTO, cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -475,7 +475,7 @@ def _to_mxfp8_dim1_3d( block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=B_reshaped.dtype, - kernel_preference=KernelPreference.AUTO, # Not used + kernel_preference=KernelPreference.AUTO, cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scaling_mode, ) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index b4cd192244..72a19e2c86 100644 --- 
a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -626,9 +626,10 @@ def triton_mxfp8_dequant_dim0( scale_block_size: int = 32, ) -> torch.Tensor: assert scale_block_size == 32, "scale_block_size must be 32 for now" - assert out_dtype in (torch.bfloat16, torch.float32), ( - "out_dtype must be bf16 or fp32" - ) + assert out_dtype in ( + torch.bfloat16, + torch.float32, + ), "out_dtype must be bf16 or fp32" # Input shape must be 2D. orig_shape = e4m3_data.shape @@ -1055,6 +1056,7 @@ def _(scale_tensor): padded_cols = n_col_blocks * 4 return scale_tensor.new_empty((padded_rows, padded_cols)) + else: def triton_to_mxfp8_dim0( @@ -1216,6 +1218,7 @@ def custom_mxfp8_quantize_cuda_dim1_sharding( rule_for_input_sharded_dim1, ] return acceptable_shardings + else: def mxfp8_quantize_cuda(