@@ -887,7 +887,7 @@ __global__ void k_indexer_logits_wmma16_f32_hgrp(
     const int * __restrict__ starts,
     const int * __restrict__ ends,
     float * __restrict__ Out) {
-#if __CUDA_ARCH__ >= 900
+#if __CUDA_ARCH__ >= 800
     const int tokens_per_tile = 1;
     const int t0 = blockIdx.x * tokens_per_tile;
     const int k0 = blockIdx.y * 16;
@@ -923,74 +923,69 @@ __global__ void k_indexer_logits_wmma16_f32_hgrp(
     }
     __syncthreads();
 
-    // Shared buffers for one K-block (we'll use K_block=32) and FP32 accumulators
-    const int K_block = 32;
-    __shared__ uint8_t A_fp8[16 * K_block];   // K tile, FP8 E4M3
-    __shared__ uint8_t B_fp8[16 * K_block];   // Q tile, FP8 E4M3
-    __shared__ float C_sh[16 * 16];           // accumulator
-    __shared__ float S_acc[16];
+    __shared__ __half A_sh[16 * 16];   // row-major K tile (FP8-quantized then decoded)
+    __shared__ __half B_sh[16 * 16];   // col-major Q tile (FP8-quantized then decoded)
+    __shared__ float C_sh[16 * 16];    // accumulator dump
+    __shared__ float S_acc[16];        // accumulate per kv row
+
     if (threadIdx.x < 16) S_acc[threadIdx.x] = 0.0f;
     __syncthreads();
 
     for (int h0 = 0; h0 < H; h0 += 16) {
-        // Zero per-group accum
-        for (int i = lane; i < 16 * 16; i += 32) C_sh[i] = 0.0f;
-        __syncthreads();
-
-        // Iterate D in K_block chunks
-        for (int d0 = 0; d0 < D; d0 += K_block) {
-            int curK = min(K_block, D - d0);
-            // FP8-encode K into A_fp8
-            for (int idx = lane; idx < 16 * curK; idx += 32) {
-                int mi = idx / curK;
-                int di = idx % curK;
+        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
+        wmma::fill_fragment(c_frag, 0.0f);
+
+        // Iterate K dimension in 16-slices
+        for (int d0 = 0; d0 < D; d0 += 16) {
+            int lane2 = threadIdx.x & 31;
+            // Load A_sh: rows are kv rows, cols are k-slice, with FP8 quant/dequant and per-row scale
+            for (int idx = lane2; idx < 16 * 16; idx += 32) {
+                int mi = idx / 16;   // row
+                int di = idx % 16;   // col
                 int kv_idx = k0 + mi;
-                uint8_t code = 0;
-                if (kv_idx < kv) {
-                    float f = 0.0f;
-                    if (d0 + di < D) {
-                        f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx];
-                        float sf = K_sf[mi];
-                        float scaled = f / sf;
-                        code = f32_to_fp8e4m3(scaled);
-                    }
+                __half v = __float2half_rn(0.0f);
+                if (kv_idx < kv && d0 + di < D) {
+                    float f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx];
+                    float sf = K_sf[mi];
+                    float scaled = f / sf;
+                    uint8_t code = f32_to_fp8e4m3(scaled);
+                    float dec = fp8e4m3_to_f32(code);
+                    v = __float2half_rn(dec);
                 }
-                A_fp8[mi * curK + di] = code;
+                A_sh[mi * 16 + di] = v;
             }
-            // FP8-encode Q into B_fp8
-            for (int idx = lane; idx < 16 * curK; idx += 32) {
-                int di = idx % curK;
-                int cj = idx / curK;   // 0..15 heads in group
+            // Load B_sh: columns=16 heads in group, rows=16 k-slice; col-major, FP8 quant/dequant
+            for (int idx = lane2; idx < 16 * 16; idx += 32) {
+                int di = idx / 16;   // k index
+                int cj = idx % 16;   // head col 0..15
                 int h = h0 + cj;
-                int tok = t0;
-                uint8_t code = 0;
+                int tok = t0;   // one token per tile
+                __half v = __float2half_rn(0.0f);
                 if (tok < Tc && h < H && d0 + di < D) {
                     float f = Q[(size_t)(d0 + di) + (size_t)D * (size_t)(tok*H + h)];
-                    code = f32_to_fp8e4m3(f);
+                    uint8_t code = f32_to_fp8e4m3(f);
+                    float dec = fp8e4m3_to_f32(code);
+                    v = __float2half_rn(dec);
                 }
-                B_fp8[cj * curK + di] = code;
+                B_sh[cj * 16 + di] = v;
             }
             __syncthreads();
 
-            // Naive FP8 matmul: accumulate into C_sh (16x16) in FP32
-            for (int mi = lane; mi < 16; mi += 32) {
-                for (int cj = 0; cj < 16; ++cj) {
-                    float acc = 0.0f;
-                    for (int di = 0; di < curK; ++di) {
-                        uint8_t a = A_fp8[mi * curK + di];
-                        uint8_t b = B_fp8[cj * curK + di];
-                        float fa = fp8e4m3_to_f32(a);
-                        float fb = fp8e4m3_to_f32(b);
-                        acc += fa * fb;
-                    }
-                    C_sh[mi * 16 + cj] += acc;
-                }
-            }
+            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
+            wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> b_frag;
+            wmma::load_matrix_sync(a_frag, A_sh, 16);
+            wmma::load_matrix_sync(b_frag, B_sh, 16);
+            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
             __syncthreads();
         }
 
-        // Post-process: ReLU + head weights into S_acc
-        for (int mi = lane; mi < 16; mi += 32) {
+        // Dump accumulators to shared
+        wmma::store_matrix_sync(C_sh, c_frag, 16, wmma::mem_row_major);
+        __syncthreads();
+
+        // Accumulate this head-group contribution into S_acc per row
+        int lane3 = threadIdx.x & 31;
+        for (int mi = lane3; mi < 16; mi += 32) {
             float srow = 0.0f;
             for (int cj = 0; cj < 16; ++cj) {
                 float v = C_sh[mi * 16 + cj];
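The new path quantizes K and Q to FP8 E4M3 and immediately decodes back to half before the WMMA tiles, so the tensor-core math sees FP8-rounded values while still accumulating in FP32. The diff does not show `f32_to_fp8e4m3` / `fp8e4m3_to_f32` themselves; the following is a minimal, hypothetical sketch of the round-trip they are assumed to perform, written against CUDA's `cuda_fp8.h` conversion intrinsics. The repo's hand-rolled helpers may differ in rounding and saturation details.

// Hypothetical FP8 E4M3 round-trip helpers, sketched with cuda_fp8.h (CUDA 11.8+).
// These only illustrate the encode/decode intent assumed by the kernel; the actual
// f32_to_fp8e4m3 / fp8e4m3_to_f32 in the repo may be hand-rolled and behave differently.
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __forceinline__ uint8_t f32_to_fp8e4m3_sketch(float x) {
    // Round to nearest, saturating to the finite E4M3 range (E4M3 has no Inf encoding).
    return (uint8_t)__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
}

__device__ __forceinline__ float fp8e4m3_to_f32_sketch(uint8_t code) {
    // Every E4M3 value is exactly representable in FP16, so decode via half and widen.
    __half_raw h = __nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t)code, __NV_E4M3);
    return __half2float(__half(h));
}

Decoding back to __half keeps the matmul on the standard 16x16x16 half-precision WMMA tiles while preserving the FP8 rounding of the values, which is the trade-off this commit makes instead of using native FP8 MMA paths.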