
Commit 82ed1e6

Restore 6.23 tok/s performance in WMMA HGRP kernel while retaining FP8 accuracy.
1 parent d25442a commit 82ed1e6

1 file changed: 47 additions, 52 deletions

ggml/src/ggml-cuda/indexer-fused.cu

Lines changed: 47 additions & 52 deletions
@@ -887,7 +887,7 @@ __global__ void k_indexer_logits_wmma16_f32_hgrp(
         const int * __restrict__ starts,
         const int * __restrict__ ends,
         float * __restrict__ Out) {
-#if __CUDA_ARCH__ >= 900
+#if __CUDA_ARCH__ >= 800
     const int tokens_per_tile = 1;
     const int t0 = blockIdx.x * tokens_per_tile;
     const int k0 = blockIdx.y * 16;
@@ -923,74 +923,69 @@ __global__ void k_indexer_logits_wmma16_f32_hgrp(
     }
     __syncthreads();

-    // Shared buffers for one K-block (we'll use K_block=32) and FP32 accumulators
-    const int K_block = 32;
-    __shared__ uint8_t A_fp8[16 * K_block]; // K tile, FP8 E4M3
-    __shared__ uint8_t B_fp8[16 * K_block]; // Q tile, FP8 E4M3
-    __shared__ float C_sh[16 * 16]; // accumulator
-    __shared__ float S_acc[16];
+    __shared__ __half A_sh[16*16]; // row-major K tile (FP8-quantized then decoded)
+    __shared__ __half B_sh[16*16]; // col-major Q tile (FP8-quantized then decoded)
+    __shared__ float C_sh[16*16];  // accumulator dump
+    __shared__ float S_acc[16];    // accumulate per kv row
+
     if (threadIdx.x < 16) S_acc[threadIdx.x] = 0.0f;
     __syncthreads();

     for (int h0 = 0; h0 < H; h0 += 16) {
-        // Zero per-group accum
-        for (int i = lane; i < 16 * 16; i += 32) C_sh[i] = 0.0f;
-        __syncthreads();
-
-        // Iterate D in K_block chunks
-        for (int d0 = 0; d0 < D; d0 += K_block) {
-            int curK = min(K_block, D - d0);
-            // FP8-encode K into A_fp8
-            for (int idx = lane; idx < 16 * curK; idx += 32) {
-                int mi = idx / curK;
-                int di = idx % curK;
+        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
+        wmma::fill_fragment(c_frag, 0.0f);
+
+        // Iterate K dimension in 16-slices
+        for (int d0 = 0; d0 < D; d0 += 16) {
+            int lane2 = threadIdx.x & 31;
+            // Load A_sh: rows are kv rows, cols are k-slice, with FP8 quant/dequant and per-row scale
+            for (int idx = lane2; idx < 16*16; idx += 32) {
+                int mi = idx / 16; // row
+                int di = idx % 16; // col
                 int kv_idx = k0 + mi;
-                uint8_t code = 0;
-                if (kv_idx < kv) {
-                    float f = 0.0f;
-                    if (d0 + di < D) {
-                        f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx];
-                        float sf = K_sf[mi];
-                        float scaled = f / sf;
-                        code = f32_to_fp8e4m3(scaled);
-                    }
+                __half v = __float2half_rn(0.0f);
+                if (kv_idx < kv && d0 + di < D) {
+                    float f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx];
+                    float sf = K_sf[mi];
+                    float scaled = f / sf;
+                    uint8_t code = f32_to_fp8e4m3(scaled);
+                    float dec = fp8e4m3_to_f32(code);
+                    v = __float2half_rn(dec);
                 }
-                A_fp8[mi * curK + di] = code;
+                A_sh[mi * 16 + di] = v;
             }
-            // FP8-encode Q into B_fp8
-            for (int idx = lane; idx < 16 * curK; idx += 32) {
-                int di = idx % curK;
-                int cj = idx / curK; // 0..15 heads in group
+            // Load B_sh: columns=16 heads in group, rows=16 k-slice; col-major, FP8 quant/dequant
+            for (int idx = lane2; idx < 16*16; idx += 32) {
+                int di = idx / 16; // k index
+                int cj = idx % 16; // head col 0..15
                 int h = h0 + cj;
-                int tok = t0;
-                uint8_t code = 0;
+                int tok = t0; // one token per tile
+                __half v = __float2half_rn(0.0f);
                 if (tok < Tc && h < H && d0 + di < D) {
                     float f = Q[(size_t)(d0 + di) + (size_t)D * (size_t)(tok*H + h)];
-                    code = f32_to_fp8e4m3(f);
+                    uint8_t code = f32_to_fp8e4m3(f);
+                    float dec = fp8e4m3_to_f32(code);
+                    v = __float2half_rn(dec);
                 }
-                B_fp8[cj * curK + di] = code;
+                B_sh[cj * 16 + di] = v;
             }
             __syncthreads();

-            // Naive FP8 matmul: accumulate into C_sh (16x16) in FP32
-            for (int mi = lane; mi < 16; mi += 32) {
-                for (int cj = 0; cj < 16; ++cj) {
-                    float acc = 0.0f;
-                    for (int di = 0; di < curK; ++di) {
-                        uint8_t a = A_fp8[mi * curK + di];
-                        uint8_t b = B_fp8[cj * curK + di];
-                        float fa = fp8e4m3_to_f32(a);
-                        float fb = fp8e4m3_to_f32(b);
-                        acc += fa * fb;
-                    }
-                    C_sh[mi * 16 + cj] += acc;
-                }
-            }
+            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
+            wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> b_frag;
+            wmma::load_matrix_sync(a_frag, A_sh, 16);
+            wmma::load_matrix_sync(b_frag, B_sh, 16);
+            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
             __syncthreads();
         }

-        // Post-process: ReLU + head weights into S_acc
-        for (int mi = lane; mi < 16; mi += 32) {
+        // Dump accumulators to shared
+        wmma::store_matrix_sync(C_sh, c_frag, 16, wmma::mem_row_major);
+        __syncthreads();
+
+        // Accumulate this head-group contribution into S_acc per row
+        int lane3 = threadIdx.x & 31;
+        for (int mi = lane3; mi < 16; mi += 32) {
             float srow = 0.0f;
             for (int cj = 0; cj < 16; ++cj) {
                 float v = C_sh[mi * 16 + cj];