
Commit bbecb21

Vendored DeepGEMM compiling successfully. Test fails though.
1 parent b82a9d6 commit bbecb21

File tree

3 files changed: +616 −1 lines changed


ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
@@ -72,6 +72,12 @@ if (CUDAToolkit_FOUND)
         ${GGML_SOURCES_CUDA}
     )

+    # DeepGEMM FP8 paged MQA logits (sm100) uses constexpr helpers in device code;
+    # enable relaxed constexpr just for the TU that includes the DeepGEMM kernels.
+    set_source_files_properties(indexer-fused.cu PROPERTIES
+        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>"
+    )
+

     list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "mqa_attn_return_logits_kernel\\.cu$")

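For reference, `--expt-relaxed-constexpr` lets device code call `constexpr` functions that are not annotated `__device__`, which is the pattern the comment above says the vendored DeepGEMM headers rely on. A minimal sketch of that pattern, illustrative only and not taken from this diff (the `ceil_div` helper and `k_demo` kernel are made up):

// Illustration only (not part of this commit): the kind of call that fails
// without --expt-relaxed-constexpr. ceil_div is a plain constexpr host
// function, yet it is called from a __global__ kernel; nvcc only accepts
// this when the relaxed-constexpr extension is enabled.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

__global__ void k_demo(int n) {
    int blocks = ceil_div(n, 64);   // needs --expt-relaxed-constexpr
    if (threadIdx.x == 0) printf("blocks = %d\n", blocks);
}

int main() {
    k_demo<<<1, 32>>>(1000);
    cudaDeviceSynchronize();
    return 0;
}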
ggml/src/ggml-cuda/indexer-fused.cu

Lines changed: 282 additions & 1 deletion
@@ -7,8 +7,62 @@
 using namespace nvcuda;

 #include <cuda_runtime.h>
+#include <cuda.h>
+#include <cute/arch/copy_sm90_desc.hpp>
+#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh>

 #include <cuda_pipeline_primitives.h>
+
+#if CUDART_VERSION >= 12000
+static inline bool dg_fp8_encode_tma_2d(
+        cute::TmaDescriptor &desc,
+        CUtensorMapDataType type,
+        void *base,
+        uint32_t inner_dim, uint32_t outer_dim,
+        uint32_t box_inner, uint32_t box_outer,
+        uint32_t outer_stride_elems,
+        size_t elem_size) {
+    cuuint64_t dims[2] = { (cuuint64_t) inner_dim, (cuuint64_t) outer_dim };
+    cuuint64_t strides[1] = { (cuuint64_t) outer_stride_elems * (cuuint64_t) elem_size };
+    cuuint32_t box[2] = { (cuuint32_t) box_inner, (cuuint32_t) box_outer };
+    cuuint32_t elem_strides[2] = { 1u, 1u };
+    CUresult res = cuTensorMapEncodeTiled(
+        reinterpret_cast<CUtensorMap*>(&desc), type,
+        2u, base, dims, strides, box, elem_strides,
+        CU_TENSOR_MAP_INTERLEAVE_NONE,
+        CU_TENSOR_MAP_SWIZZLE_NONE,
+        CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
+        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+    return res == CUDA_SUCCESS;
+}
+
+static inline bool dg_fp8_encode_tma_3d(
+        cute::TmaDescriptor &desc,
+        CUtensorMapDataType type,
+        void *base,
+        uint32_t dim0, uint32_t dim1, uint32_t dim2,
+        uint32_t box0, uint32_t box1, uint32_t box2,
+        uint32_t stride0_elems, uint32_t stride1_elems,
+        size_t elem_size) {
+    cuuint64_t dims[3] = { (cuuint64_t) dim0, (cuuint64_t) dim1, (cuuint64_t) dim2 };
+    cuuint64_t strides[2] = {
+        (cuuint64_t) stride0_elems * (cuuint64_t) elem_size,
+        (cuuint64_t) stride1_elems * (cuuint64_t) elem_size
+    };
+    cuuint32_t box[3] = { (cuuint32_t) box0, (cuuint32_t) box1, (cuuint32_t) box2 };
+    cuuint32_t elem_strides[3] = { 1u, 1u, 1u };
+    CUresult res = cuTensorMapEncodeTiled(
+        reinterpret_cast<CUtensorMap*>(&desc), type,
+        3u, base, dims, strides, box, elem_strides,
+        CU_TENSOR_MAP_INTERLEAVE_NONE,
+        CU_TENSOR_MAP_SWIZZLE_NONE,
+        CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
+        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+    return res == CUDA_SUCCESS;
+}
+#endif // CUDART_VERSION >= 12000
+
+
 #include <mma.h>
 #include <stdint.h>
 #include <stdio.h>
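The two helpers above wrap `cuTensorMapEncodeTiled`: dimension 0 is the contiguous (innermost) axis, the `strides` array skips that contiguous axis and is expressed in bytes, and the box sizes describe the tile that each TMA copy moves into shared memory. A hedged usage sketch, not taken from the commit; the matrix shape, the `d_fp8` pointer, and the box sizes are made-up example values:

// Hypothetical example (not from the commit): describe a row-major FP8 matrix
// of 4096 rows x 128 columns, fetched in 128x64 boxes. d_fp8 is assumed to be
// a device allocation of 4096*128 bytes made elsewhere.
cute::TmaDescriptor desc{};
const bool ok = dg_fp8_encode_tma_2d(
    desc, CU_TENSOR_MAP_DATA_TYPE_UINT8,
    d_fp8,
    /*inner_dim=*/128,            // contiguous elements per row
    /*outer_dim=*/4096,           // number of rows
    /*box_inner=*/128,            // box covers a full row
    /*box_outer=*/64,             // 64 rows per TMA copy
    /*outer_stride_elems=*/128,   // row pitch in elements (tight packing)
    sizeof(unsigned char));
if (!ok) {
    // e.g. fall back to a non-TMA path or report the failure
}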
@@ -1145,11 +1199,238 @@ extern "C" void ggml_cuda_indexer_logits_fused_device(ggml_backend_cuda_context
         }
     }

+    const char *dg_env = getenv("LLAMA_DG_FP8");
+    const bool use_dg_fp8 = (dg_env && *dg_env && atoi(dg_env) != 0);
     if (sparse_debug_on()) printf("[INDEXER_DISPATCH] use_wmma=%d D=%d H=%d Tc=%d kv=%d BLOCK_Q=%d BLOCK_N=%d D_TILE=%d\n", (int)use_wmma, D, H, Tc, kv_end, BLOCK_Q, BLOCK_N, D_TILE);
     // Optional: TL port path in device wrapper
     const char * __prof_env = getenv("LLAMA_SPARSE_PROF");
     auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH");
-    if (const char *s = getenv("LLAMA_INDEXER_TL_PORT"); s && atoi(s) != 0) {
+    if (use_dg_fp8) {
+        if (sparse_debug_on()) fprintf(stderr, "[INDEXER_DISPATCH] using DeepGEMM FP8 paged MQA logits path\n");
+
+#if CUDART_VERSION >= 12000
+        // DeepSeek V3.2-Exp: use DeepGEMM FP8 paged MQA logits kernel when shapes match
+        const int block_kv = 64;
+        const int num_math_warp_groups = 4;
+        const int num_specialized_threads = 128;
+        const int num_math_threads = num_math_warp_groups * 128;
+        if (H == 64 && (D == 64 || D == 128) && (Tc == 1 || Tc == 2) && kv_end % block_kv == 0) {
+            int batch_size = 1;
+            int next_n = Tc;
+            int num_heads = H;
+            int head_dim = D;
+            int num_kv_blocks = kv_end / block_kv;
+            int max_context_len = kv_end;
+
+            ggml_cuda_pool & pool = ctx.pool(ggml_cuda_get_device());
+
+            // Build DeepGEMM-style context_lens (1D) and block_table
+            ggml_cuda_pool_alloc<unsigned int> __ctx_lens(pool, batch_size);
+            ggml_cuda_pool_alloc<unsigned int> __block_tbl(pool, (size_t)batch_size * (size_t)num_kv_blocks);
+            unsigned int *d_ctx_lens = __ctx_lens.get();
+            unsigned int *d_block_tbl = __block_tbl.get();
+            CUDA_CHECK(cudaMemsetAsync(d_block_tbl, 0, sizeof(unsigned int)*(size_t)batch_size*(size_t)num_kv_blocks, stream));
+            unsigned int h_ctx[1] = { (unsigned int) kv_end };
+            CUDA_CHECK(cudaMemcpyAsync(d_ctx_lens, h_ctx, sizeof(unsigned int), cudaMemcpyHostToDevice, stream));
+            // block_table[0][i] = i
+            CUDA_CHECK(cudaMemsetAsync(d_block_tbl, 0, sizeof(unsigned int)*(size_t)num_kv_blocks, stream));
+            {
+                // simple host init for block_table
+                std::vector<unsigned int> h_bt(num_kv_blocks);
+                for (int i = 0; i < num_kv_blocks; ++i) h_bt[i] = (unsigned int) i;
+                CUDA_CHECK(cudaMemcpyAsync(d_block_tbl, h_bt.data(), sizeof(unsigned int)*h_bt.size(), cudaMemcpyHostToDevice, stream));
+            }
+
+            // Build schedule_meta on host mirroring DeepGEMM scheduler
+            int dev = ggml_cuda_get_device();
+            int num_sms = 0;
+            CUDA_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));
+            ggml_cuda_pool_alloc<unsigned int> __sched(pool, (size_t)(num_sms + 1) * 2);
+            unsigned int *d_sched = __sched.get();
+            std::vector<unsigned int> h_sched((size_t)(num_sms + 1) * 2, 0);
+            {
+                const int aligned_batch_size = ((batch_size + 31) / 32) * 32;
+                const int split_kv = block_kv * num_math_warp_groups;
+                // For our single-q case, num_segs for q=0 is ceil_div(context_len, split_kv)
+                int context_len = kv_end;
+                int num_segs = (context_len + split_kv - 1) / split_kv;
+                unsigned int total_segs = (unsigned int) num_segs;
+                unsigned int q = total_segs / (unsigned int) num_sms;
+                unsigned int r = total_segs % (unsigned int) num_sms;
+                for (int sm = 0; sm <= num_sms; ++sm) {
+                    unsigned int seg_starts = (unsigned int) sm * q + (sm < (int)r ? (unsigned int) sm : (unsigned int) r);
+                    unsigned int q_idx = (seg_starts == 0 ? 0u : (unsigned int) batch_size);
+                    unsigned int kv_split_idx = (seg_starts == 0 ? 0u : (seg_starts - 1));
+                    h_sched[(size_t)sm*2 + 0] = q_idx;
+                    h_sched[(size_t)sm*2 + 1] = kv_split_idx;
+                }
+            }
+            CUDA_CHECK(cudaMemcpyAsync(d_sched, h_sched.data(), sizeof(unsigned int)*h_sched.size(), cudaMemcpyHostToDevice, stream));
+
+            // Allocate logits [batch*next_n, aligned_max_context_len]
+            const int aligned_max_context_len = ((max_context_len + num_math_warp_groups*block_kv - 1) / (num_math_warp_groups*block_kv))*(num_math_warp_groups*block_kv);
+            ggml_cuda_pool_alloc<float> __dg_logits(pool, (size_t)batch_size * (size_t)next_n * (size_t)aligned_max_context_len);
+            float *d_dg_logits = __dg_logits.get();
+            CUDA_CHECK(cudaMemsetAsync(d_dg_logits, 0, sizeof(float)*(size_t)batch_size*(size_t)next_n*(size_t)aligned_max_context_len, stream));
+
+            // Build FP8 Q, K and scales in row-major and flatten
+            ggml_cuda_pool_alloc<float> __Qrm(pool, (size_t)(Tc*H) * (size_t)D);
+            ggml_cuda_pool_alloc<float> __Krm(pool, (size_t)kv_end * (size_t)D);
+            ggml_cuda_pool_alloc<float> __Wrm(pool, (size_t)Tc * (size_t)H);
+            float *dQrm = __Qrm.get();
+            float *dKrm = __Krm.get();
+            float *dWrm = __Wrm.get();
+            dim3 tbT(32, 8);
+            dim3 gdQ((Tc*H + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y);
+            dim3 gdK((kv_end + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y);
+            dim3 gdW((Tc + tbT.x - 1)/tbT.x, (H + tbT.y - 1)/tbT.y);
+            k_colmajor_DN_to_rowmajor_ND<<<gdQ, tbT, 0, stream>>>(dQ, D, Tc*H, dQrm);
+            k_colmajor_DN_to_rowmajor_ND<<<gdK, tbT, 0, stream>>>(dK, D, kv_end, dKrm);
+            k_colmajor_DN_to_rowmajor_ND<<<gdW, tbT, 0, stream>>>(dW, H, Tc, dWrm);
+
+            // FP8 Q (no per-row scaling)
+            ggml_cuda_pool_alloc<unsigned char> __Qfp8(pool, (size_t)(Tc*H) * (size_t)D);
+            unsigned char *dQfp8 = __Qfp8.get();
+            {
+                size_t total = (size_t)(Tc*H) * (size_t)D;
+                dim3 tb(256);
+                dim3 gd((unsigned)((total + tb.x - 1)/tb.x));
+                k_rowmajor_f32_to_fp8_e4m3<<<gd, tb, 0, stream>>>(dQrm, (int)(Tc*H), D, dQfp8);
+            }
+
+            // FP8 K with per-row scaling and combined k_scale
+            ggml_cuda_pool_alloc<unsigned char> __Kfp8(pool, (size_t)kv_end * (size_t)D);
+            ggml_cuda_pool_alloc<float> __Kamax(pool, (size_t)kv_end);
+            ggml_cuda_pool_alloc<float> __Ksf(pool, (size_t)kv_end);
+            ggml_cuda_pool_alloc<float> __KsfInv(pool, (size_t)kv_end);
+            ggml_cuda_pool_alloc<float> __IdxKScale(pool, (size_t)kv_end);
+            unsigned char *dKfp8 = __Kfp8.get();
+            float *dKamax = __Kamax.get();
+            float *dKsf = __Ksf.get();
+            float *dKsfInv = __KsfInv.get();
+            float *dIdxKScale = __IdxKScale.get();
+            {
+                int rowsK = kv_end;
+                int colsK = D;
+                int threadsA = 256;
+                int blocksA = (rowsK + threadsA - 1) / threadsA;
+                k_rowmajor_f32_rowwise_absmax<<<blocksA, threadsA, 0, stream>>>(dKrm, rowsK, colsK, dKamax);
+                k_fp8_compute_row_scales<<<blocksA, threadsA, 0, stream>>>(dKamax, rowsK, dKsf, dKsfInv);
+                size_t total = (size_t)rowsK * (size_t)colsK;
+                dim3 tb(256);
+                dim3 gd((unsigned)((total + tb.x - 1)/tb.x));
+                k_rowmajor_f32_to_fp8_e4m3_rowwise_scaled<<<gd, tb, 0, stream>>>(dKrm, rowsK, colsK, dKsfInv, dKfp8);
+                k_elemwise_mul<<<blocksA, threadsA, 0, stream>>>(dKS, dKsf, dIdxKScale, rowsK);
+            }
+
+            // Build fused KV cache buffer [num_kv_blocks, block_kv, 1, head_dim+4]
+            const int head_dim_with_sf = head_dim + 4;
+            ggml_cuda_pool_alloc<unsigned char> __FusedKV(pool, (size_t)num_kv_blocks * (size_t)block_kv * (size_t)head_dim_with_sf);
+            unsigned char *dFusedKV = __FusedKV.get();
+            CUDA_CHECK(cudaMemsetAsync(dFusedKV, 0, (size_t)num_kv_blocks * (size_t)block_kv * (size_t)head_dim_with_sf, stream));
+            {
+                // Layout: [num_blocks, block_kv, 1, head_dim+4]
+                dim3 tb(256);
+                dim3 gd((unsigned)((kv_end * head_dim + tb.x - 1)/tb.x));
+                // pack Kfp8 into leading head_dim slots; then we will separately feed scales via TMA
+                // here we only lay out FP8 values row-major
+                // index: row r in [0,kv_end), col c in [0,head_dim)
+                auto pack = [=] __device__ (int idx) {};
+            }
+            // For DeepGEMM kernel we actually pass Kfp8 and scales separately via tensor_map_kv and tensor_map_kv_scales,
+            // so we do not need a true fused buffer here; instead we treat dKfp8 as [num_kv_blocks, block_kv, head_dim]
+
+            // Create TMA descriptors using driver API wrappers
+            cute::TmaDescriptor tma_q{}, tma_kv{}, tma_kv_scales{}, tma_w{};
+            // Q: [head_dim, batch*next_n*num_heads]
+            dg_fp8_encode_tma_2d(tma_q, CU_TENSOR_MAP_DATA_TYPE_UINT8,
+                dQfp8,
+                (uint32_t) head_dim,
+                (uint32_t)(batch_size * next_n * num_heads),
+                (uint32_t) head_dim,
+                (uint32_t)(next_n * num_heads),
+                (uint32_t) head_dim,
+                sizeof(unsigned char));
+            // KV: [head_dim, block_kv, num_kv_blocks]
+            dg_fp8_encode_tma_3d(tma_kv, CU_TENSOR_MAP_DATA_TYPE_UINT8,
+                dKfp8,
+                (uint32_t) head_dim,
+                (uint32_t) block_kv,
+                (uint32_t) num_kv_blocks,
+                (uint32_t) head_dim,
+                (uint32_t) block_kv,
+                1u,
+                (uint32_t) head_dim,
+                (uint32_t)(head_dim * block_kv),
+                sizeof(unsigned char));
+            // KV scales: [block_kv, num_kv_blocks]
+            dg_fp8_encode_tma_2d(tma_kv_scales, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
+                dIdxKScale,
+                (uint32_t) block_kv,
+                (uint32_t) num_kv_blocks,
+                (uint32_t) block_kv,
+                1u,
+                (uint32_t) block_kv,
+                sizeof(float));
+            // Weights: [next_n*num_heads, batch_size]
+            dg_fp8_encode_tma_2d(tma_w, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
+                dWrm,
+                (uint32_t)(next_n * num_heads),
+                (uint32_t) batch_size,
+                (uint32_t)(next_n * num_heads),
+                1u,
+                (uint32_t)(next_n * num_heads),
+                sizeof(float));
+
+            // Launch DeepGEMM kernel
+            dim3 grid(num_sms, 1, 1);
+            dim3 block(num_specialized_threads + num_math_threads + 128, 1, 1);
+            const uint32_t kNextN = 2;
+            const uint32_t kNumHeads = 64;
+            const uint32_t kHeadDim = 128;
+            const uint32_t kNumQStages = 3;
+            const uint32_t kNumKVStages = 3;
+            const uint32_t SPLIT_KV = block_kv * num_math_warp_groups;
+            const uint32_t kNumSpecializedThreads = 128;
+            const uint32_t kNumMathThreads = num_math_warp_groups * 128;
+            size_t shmem_bytes = 0; // computed in-kernel
+
+            LAUNCH_PROFILE_KERNEL("PROFILE_DG_FP8", DG_FP8, stream, ([&](){
+                deep_gemm::sm100_fp8_paged_mqa_logits<
+                    kNextN, kNumHeads,
+                    kHeadDim, (uint32_t) block_kv,
+                    false,
+                    kNumQStages, kNumKVStages,
+                    SPLIT_KV,
+                    kNumSpecializedThreads, kNumMathThreads
+                ><<<grid, block, shmem_bytes, stream>>>(
+                    (uint32_t) batch_size,
+                    (uint64_t) aligned_max_context_len,
+                    (uint64_t) num_kv_blocks,
+                    d_ctx_lens,
+                    d_dg_logits,
+                    d_block_tbl,
+                    d_sched,
+                    tma_q,
+                    tma_kv,
+                    tma_kv_scales,
+                    tma_w);
+            })(), D, H, Tc, kv_end);
+
+            // Map logits [batch*next_n, max_context_len] back to Out [kv, Tc]
+            // For our simple case (batch=1), token t in [0,Tc), kv index k in [0,kv_end)
+            dim3 tb(32, 4);
+            dim3 gd((kv_end + tb.x - 1)/tb.x, (Tc + tb.y - 1)/tb.y);
+            k_transpose_TcKv_to_KvTc<<<gd, tb, 0, stream>>>(d_dg_logits, Tc, kv_end, dOut);
+            CUDA_CHECK(cudaGetLastError());
+            cudaStreamSynchronize(stream);
+            if (dStarts_tmp) cudaFree(dStarts_tmp);
+            if (dEnds_tmp) cudaFree(dEnds_tmp);
+            return;
+        }
+#endif // CUDART_VERSION >= 12000
+    } else if (const char *s = getenv("LLAMA_INDEXER_TL_PORT"); s && atoi(s) != 0) {
+
         ggml_cuda_pool & __pool = ctx.pool(ggml_cuda_get_device());
         bool use_tma_fp8 = false;
         if (const char *e = getenv("LLAMA_TL_FP8"); e && atoi(e) != 0) use_tma_fp8 = true;
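The new path is gated by the `LLAMA_DG_FP8` environment variable and only triggers when H == 64, D is 64 or 128, Tc is 1 or 2, and kv_end is a multiple of block_kv (64). Below is a standalone host-only sketch, not part of the commit, that reproduces the launch geometry and the per-SM segment split filled into schedule_meta; the kv_end and SM count are assumed example values, which may help when debugging the failing test mentioned in the commit message:

// Host-only sketch (not from the commit) of the geometry computed in the
// DeepGEMM FP8 path above: KV block count, segment count, aligned context
// length, and the per-SM prefix of assigned segments.
#include <cstdio>

int main() {
    const int kv_end = 4096;                 // assumed context length, multiple of block_kv
    const int block_kv = 64;
    const int num_math_warp_groups = 4;
    const int num_sms = 132;                 // assumed SM count

    const int split_kv = block_kv * num_math_warp_groups;       // 256
    const int num_kv_blocks = kv_end / block_kv;
    const int num_segs = (kv_end + split_kv - 1) / split_kv;
    const int aligned_max_context_len = num_segs * split_kv;    // same rounding as the diff

    printf("kv_blocks=%d segs=%d aligned_ctx=%d\n", num_kv_blocks, num_segs, aligned_max_context_len);

    // seg_starts for a given sm is the number of segments assigned to SMs [0, sm)
    // when num_segs segments are split evenly with the remainder going to the first r SMs.
    const unsigned q = (unsigned) num_segs / (unsigned) num_sms;
    const unsigned r = (unsigned) num_segs % (unsigned) num_sms;
    for (int sm = 0; sm <= num_sms; ++sm) {
        const unsigned seg_starts = (unsigned) sm * q + ((unsigned) sm < r ? (unsigned) sm : r);
        if (sm < 4 || sm == num_sms) {
            printf("sm=%3d seg_starts=%u\n", sm, seg_starts);
        }
    }
    return 0;
}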
