7 | 7 | using namespace nvcuda;
8 | 8 |
9 | 9 | #include <cuda_runtime.h> |
| 10 | +#include <cuda.h> |
| 11 | +#include <cute/arch/copy_sm90_desc.hpp> |
| 12 | +#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh> |
10 | 13 |
11 | 14 | #include <cuda_pipeline_primitives.h> |
| 15 | + |
| 16 | +#if CUDART_VERSION >= 12000 |
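| 17 | +// Thin wrapper around cuTensorMapEncodeTiled for a 2D tile: dim 0 is the contiguous inner
| 18 | +// dimension, dim 1 advances by outer_stride_elems * elem_size bytes. The base pointer must be
| 19 | +// 16-byte aligned and the outer stride a multiple of 16 bytes.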
| 17 | +static inline bool dg_fp8_encode_tma_2d( |
| 18 | + cute::TmaDescriptor &desc, |
| 19 | + CUtensorMapDataType type, |
| 20 | + void *base, |
| 21 | + uint32_t inner_dim, uint32_t outer_dim, |
| 22 | + uint32_t box_inner, uint32_t box_outer, |
| 23 | + uint32_t outer_stride_elems, |
| 24 | + size_t elem_size) { |
| 25 | + cuuint64_t dims[2] = { (cuuint64_t) inner_dim, (cuuint64_t) outer_dim }; |
| 26 | + cuuint64_t strides[1] = { (cuuint64_t) outer_stride_elems * (cuuint64_t) elem_size }; |
| 27 | + cuuint32_t box[2] = { (cuuint32_t) box_inner, (cuuint32_t) box_outer }; |
| 28 | + cuuint32_t elem_strides[2] = { 1u, 1u }; |
| 29 | + CUresult res = cuTensorMapEncodeTiled( |
| 30 | + reinterpret_cast<CUtensorMap*>(&desc), type, |
| 31 | + 2u, base, dims, strides, box, elem_strides, |
| 32 | + CU_TENSOR_MAP_INTERLEAVE_NONE, |
| 33 | + CU_TENSOR_MAP_SWIZZLE_NONE, |
| 34 | + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, |
| 35 | + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); |
| 36 | + return res == CUDA_SUCCESS; |
| 37 | +} |
| 38 | + |
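| 39 | +// 3D variant: global strides are supplied for dims 1 and 2 only, since cuTensorMapEncodeTiled
| 40 | +// treats dim 0 as contiguous with an implicit stride of elem_size.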
| 39 | +static inline bool dg_fp8_encode_tma_3d( |
| 40 | + cute::TmaDescriptor &desc, |
| 41 | + CUtensorMapDataType type, |
| 42 | + void *base, |
| 43 | + uint32_t dim0, uint32_t dim1, uint32_t dim2, |
| 44 | + uint32_t box0, uint32_t box1, uint32_t box2, |
| 45 | + uint32_t stride0_elems, uint32_t stride1_elems, |
| 46 | + size_t elem_size) { |
| 47 | + cuuint64_t dims[3] = { (cuuint64_t) dim0, (cuuint64_t) dim1, (cuuint64_t) dim2 }; |
| 48 | + cuuint64_t strides[2] = { |
| 49 | + (cuuint64_t) stride0_elems * (cuuint64_t) elem_size, |
| 50 | + (cuuint64_t) stride1_elems * (cuuint64_t) elem_size |
| 51 | + }; |
| 52 | + cuuint32_t box[3] = { (cuuint32_t) box0, (cuuint32_t) box1, (cuuint32_t) box2 }; |
| 53 | + cuuint32_t elem_strides[3] = { 1u, 1u, 1u }; |
| 54 | + CUresult res = cuTensorMapEncodeTiled( |
| 55 | + reinterpret_cast<CUtensorMap*>(&desc), type, |
| 56 | + 3u, base, dims, strides, box, elem_strides, |
| 57 | + CU_TENSOR_MAP_INTERLEAVE_NONE, |
| 58 | + CU_TENSOR_MAP_SWIZZLE_NONE, |
| 59 | + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, |
| 60 | + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); |
| 61 | + return res == CUDA_SUCCESS; |
| 62 | +} |
| 63 | +#endif // CUDART_VERSION >= 12000 |
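| 64 | +// Usage sketch (illustrative shapes and pointer name): encode a descriptor for a 128 x 4096
| 65 | +// row-major FP8 matrix copied in 128 x 64 boxes, given a 16-byte-aligned device pointer d_fp8:
| 66 | +//   cute::TmaDescriptor desc{};
| 67 | +//   bool ok = dg_fp8_encode_tma_2d(desc, CU_TENSOR_MAP_DATA_TYPE_UINT8, d_fp8,
| 68 | +//                                  /*dims*/ 128u, 4096u, /*box*/ 128u, 64u,
| 69 | +//                                  /*outer stride in elems*/ 128u, sizeof(unsigned char));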
| 64 | + |
| 65 | + |
12 | 66 | #include <mma.h> |
13 | 67 | #include <stdint.h> |
14 | 68 | #include <stdio.h> |
@@ -1145,11 +1199,238 @@ extern "C" void ggml_cuda_indexer_logits_fused_device(ggml_backend_cuda_context |
1145 | 1199 | } |
1146 | 1200 | } |
1147 | 1201 |
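| 1202 | + // Opt-in: LLAMA_DG_FP8=1 routes the indexer logits through the DeepGEMM SM100 FP8
| 1203 | + // paged MQA logits kernel when the shape guard below matches.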
| 1202 | + const char *dg_env = getenv("LLAMA_DG_FP8"); |
| 1203 | + const bool use_dg_fp8 = (dg_env && *dg_env && atoi(dg_env) != 0); |
1148 | 1204 | if (sparse_debug_on()) printf("[INDEXER_DISPATCH] use_wmma=%d D=%d H=%d Tc=%d kv=%d BLOCK_Q=%d BLOCK_N=%d D_TILE=%d\n", (int)use_wmma, D, H, Tc, kv_end, BLOCK_Q, BLOCK_N, D_TILE); |
1149 | 1205 | // Optional: TL port path in device wrapper |
1150 | 1206 | const char * __prof_env = getenv("LLAMA_SPARSE_PROF"); |
1151 | 1207 | auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH"); |
1152 | | - if (const char *s = getenv("LLAMA_INDEXER_TL_PORT"); s && atoi(s) != 0) { |
| 1208 | + if (use_dg_fp8) { |
| 1209 | + if (sparse_debug_on()) fprintf(stderr, "[INDEXER_DISPATCH] LLAMA_DG_FP8 set: trying DeepGEMM FP8 paged MQA logits path\n");
| 1210 | + |
| 1211 | +#if CUDART_VERSION >= 12000 |
| 1212 | + // DeepSeek V3.2-Exp: use DeepGEMM FP8 paged MQA logits kernel when shapes match |
| 1213 | + const int block_kv = 64; |
| 1214 | + const int num_math_warp_groups = 4; |
| 1215 | + const int num_specialized_threads = 128; |
| 1216 | + const int num_math_threads = num_math_warp_groups * 128; |
| 1217 | + if (H == 64 && D == 128 && Tc == 2 && kv_end % (block_kv * num_math_warp_groups) == 0) { // must match the hard-coded launch parameters below; kv_end must be a whole number of KV splits
| 1218 | + int batch_size = 1; |
| 1219 | + int next_n = Tc; |
| 1220 | + int num_heads = H; |
| 1221 | + int head_dim = D; |
| 1222 | + int num_kv_blocks = kv_end / block_kv; |
| 1223 | + int max_context_len = kv_end; |
| 1224 | + |
| 1225 | + ggml_cuda_pool & pool = ctx.pool(ggml_cuda_get_device()); |
| 1226 | + |
| 1227 | + // Build DeepGEMM-style context_lens (1D) and block_table |
| 1228 | + ggml_cuda_pool_alloc<unsigned int> __ctx_lens(pool, batch_size); |
| 1229 | + ggml_cuda_pool_alloc<unsigned int> __block_tbl(pool, (size_t)batch_size * (size_t)num_kv_blocks); |
| 1230 | + unsigned int *d_ctx_lens = __ctx_lens.get(); |
| 1231 | + unsigned int *d_block_tbl = __block_tbl.get(); |
| 1232 | + // context_lens[0] = kv_end; block_table[0][i] = i (identity mapping). With batch_size == 1
| 1233 | + // the host copies below fill both buffers completely, so no memset is needed.
| 1234 | + unsigned int h_ctx[1] = { (unsigned int) kv_end };
| 1235 | + CUDA_CHECK(cudaMemcpyAsync(d_ctx_lens, h_ctx, sizeof(unsigned int), cudaMemcpyHostToDevice, stream));
| 1236 | + std::vector<unsigned int> h_bt(num_kv_blocks);
| 1237 | + for (int i = 0; i < num_kv_blocks; ++i) h_bt[i] = (unsigned int) i;
| 1238 | + CUDA_CHECK(cudaMemcpyAsync(d_block_tbl, h_bt.data(), sizeof(unsigned int)*h_bt.size(), cudaMemcpyHostToDevice, stream));
| 1243 | + |
| 1244 | + // Build schedule_meta on host mirroring DeepGEMM scheduler |
| 1245 | + int dev = ggml_cuda_get_device(); |
| 1246 | + int num_sms = 0; |
| 1247 | + CUDA_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev)); |
| 1248 | + ggml_cuda_pool_alloc<unsigned int> __sched(pool, (size_t)(num_sms + 1) * 2); |
| 1249 | + unsigned int *d_sched = __sched.get(); |
| 1250 | + std::vector<unsigned int> h_sched((size_t)(num_sms + 1) * 2, 0); |
| 1251 | + { |
| 1253 | + const int split_kv = block_kv * num_math_warp_groups; |
| 1254 | + // For our single-q case, num_segs for q=0 is ceil_div(context_len, split_kv) |
| 1255 | + int context_len = kv_end; |
| 1256 | + int num_segs = (context_len + split_kv - 1) / split_kv; |
| 1257 | + unsigned int total_segs = (unsigned int) num_segs; |
| 1258 | + unsigned int q = total_segs / (unsigned int) num_sms; |
| 1259 | + unsigned int r = total_segs % (unsigned int) num_sms; |
| 1260 | + for (int sm = 0; sm <= num_sms; ++sm) { |
| 1261 | + unsigned int seg_starts = (unsigned int) sm * q + (sm < (int)r ? (unsigned int) sm : (unsigned int) r); |
| 1262 | + unsigned int q_idx = (seg_starts == 0 ? 0u : (unsigned int) batch_size); |
| 1263 | + unsigned int kv_split_idx = (seg_starts == 0 ? 0u : (seg_starts - 1)); |
| 1264 | + h_sched[(size_t)sm*2 + 0] = q_idx; |
| 1265 | + h_sched[(size_t)sm*2 + 1] = kv_split_idx; |
| 1266 | + } |
| 1267 | + } |
| 1268 | + CUDA_CHECK(cudaMemcpyAsync(d_sched, h_sched.data(), sizeof(unsigned int)*h_sched.size(), cudaMemcpyHostToDevice, stream)); |
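| 1269 | + // Worked example (illustrative sizes): with kv_end = 8192 and 132 SMs, split_kv = 64 * 4 = 256,
| 1270 | + // num_segs = ceil(8192 / 256) = 32, q = 0, r = 32, so SMs 0..31 each start one 256-token KV
| 1271 | + // split and SMs 32..131 receive no work.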
| 1269 | + |
| 1270 | + // Allocate logits [batch*next_n, aligned_max_context_len] |
| 1271 | + const int aligned_max_context_len = ((max_context_len + num_math_warp_groups*block_kv - 1) / (num_math_warp_groups*block_kv))*(num_math_warp_groups*block_kv); |
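| 1272 | + // (row length rounded up to a multiple of num_math_warp_groups * block_kv = 256; with the
| 1273 | + // shape guard above this equals kv_end, so the rows are tightly packed)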
| 1272 | + ggml_cuda_pool_alloc<float> __dg_logits(pool, (size_t)batch_size * (size_t)next_n * (size_t)aligned_max_context_len); |
| 1273 | + float *d_dg_logits = __dg_logits.get(); |
| 1274 | + CUDA_CHECK(cudaMemsetAsync(d_dg_logits, 0, sizeof(float)*(size_t)batch_size*(size_t)next_n*(size_t)aligned_max_context_len, stream)); |
| 1275 | + |
| 1276 | + // Build FP8 Q, K and scales in row-major and flatten |
| 1277 | + ggml_cuda_pool_alloc<float> __Qrm(pool, (size_t)(Tc*H) * (size_t)D); |
| 1278 | + ggml_cuda_pool_alloc<float> __Krm(pool, (size_t)kv_end * (size_t)D); |
| 1279 | + ggml_cuda_pool_alloc<float> __Wrm(pool, (size_t)Tc * (size_t)H); |
| 1280 | + float *dQrm = __Qrm.get(); |
| 1281 | + float *dKrm = __Krm.get(); |
| 1282 | + float *dWrm = __Wrm.get(); |
| 1283 | + dim3 tbT(32, 8); |
| 1284 | + dim3 gdQ((Tc*H + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); |
| 1285 | + dim3 gdK((kv_end + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); |
| 1286 | + dim3 gdW((Tc + tbT.x - 1)/tbT.x, (H + tbT.y - 1)/tbT.y); |
| 1287 | + k_colmajor_DN_to_rowmajor_ND<<<gdQ, tbT, 0, stream>>>(dQ, D, Tc*H, dQrm); |
| 1288 | + k_colmajor_DN_to_rowmajor_ND<<<gdK, tbT, 0, stream>>>(dK, D, kv_end, dKrm); |
| 1289 | + k_colmajor_DN_to_rowmajor_ND<<<gdW, tbT, 0, stream>>>(dW, H, Tc, dWrm); |
| 1290 | + |
| 1291 | + // FP8 Q (no per-row scaling) |
| 1292 | + ggml_cuda_pool_alloc<unsigned char> __Qfp8(pool, (size_t)(Tc*H) * (size_t)D); |
| 1293 | + unsigned char *dQfp8 = __Qfp8.get(); |
| 1294 | + { |
| 1295 | + size_t total = (size_t)(Tc*H) * (size_t)D; |
| 1296 | + dim3 tb(256); |
| 1297 | + dim3 gd((unsigned)((total + tb.x - 1)/tb.x)); |
| 1298 | + k_rowmajor_f32_to_fp8_e4m3<<<gd, tb, 0, stream>>>(dQrm, (int)(Tc*H), D, dQfp8); |
| 1299 | + } |
| 1300 | + |
| 1301 | + // FP8 K with per-row scaling and combined k_scale |
| 1302 | + ggml_cuda_pool_alloc<unsigned char> __Kfp8(pool, (size_t)kv_end * (size_t)D); |
| 1303 | + ggml_cuda_pool_alloc<float> __Kamax(pool, (size_t)kv_end); |
| 1304 | + ggml_cuda_pool_alloc<float> __Ksf(pool, (size_t)kv_end); |
| 1305 | + ggml_cuda_pool_alloc<float> __KsfInv(pool, (size_t)kv_end); |
| 1306 | + ggml_cuda_pool_alloc<float> __IdxKScale(pool, (size_t)kv_end); |
| 1307 | + unsigned char *dKfp8 = __Kfp8.get(); |
| 1308 | + float *dKamax = __Kamax.get(); |
| 1309 | + float *dKsf = __Ksf.get(); |
| 1310 | + float *dKsfInv= __KsfInv.get(); |
| 1311 | + float *dIdxKScale = __IdxKScale.get(); |
| 1312 | + { |
| 1313 | + int rowsK = kv_end; |
| 1314 | + int colsK = D; |
| 1315 | + int threadsA = 256; |
| 1316 | + int blocksA = (rowsK + threadsA - 1) / threadsA; |
| 1317 | + k_rowmajor_f32_rowwise_absmax<<<blocksA, threadsA, 0, stream>>>(dKrm, rowsK, colsK, dKamax); |
| 1318 | + k_fp8_compute_row_scales<<<blocksA, threadsA, 0, stream>>>(dKamax, rowsK, dKsf, dKsfInv); |
| 1319 | + size_t total = (size_t)rowsK * (size_t)colsK; |
| 1320 | + dim3 tb(256); |
| 1321 | + dim3 gd((unsigned)((total + tb.x - 1)/tb.x)); |
| 1322 | + k_rowmajor_f32_to_fp8_e4m3_rowwise_scaled<<<gd, tb, 0, stream>>>(dKrm, rowsK, colsK, dKsfInv, dKfp8); |
| 1323 | + k_elemwise_mul<<<blocksA, threadsA, 0, stream>>>(dKS, dKsf, dIdxKScale, rowsK); |
| 1324 | + } |
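| 1325 | + // (Assumed convention for the row-wise FP8 kernels above: sf = amax / 448, 448 being the
| 1326 | + // largest finite E4M3 value; values are quantized with sf_inv and the forward scale is
| 1327 | + // folded into the combined per-row kv scale by k_elemwise_mul.)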
| 1325 | + |
| 1326 | + // No fused [num_kv_blocks, block_kv, 1, head_dim + 4] KV buffer is built here: the
| 1327 | + // DeepGEMM kernel receives the FP8 K values and their per-row scales through separate
| 1328 | + // tensor maps (tensor_map_kv and tensor_map_kv_scales), so dKfp8 is simply treated as
| 1329 | + // [num_kv_blocks, block_kv, head_dim].
| 1330 | +
| 1343 | + // Create TMA descriptors using driver API wrappers |
| 1344 | + cute::TmaDescriptor tma_q{}, tma_kv{}, tma_kv_scales{}, tma_w{}; |
| 1345 | + // Q: [head_dim, batch*next_n*num_heads] |
| 1346 | + bool tma_ok = dg_fp8_encode_tma_2d(tma_q, CU_TENSOR_MAP_DATA_TYPE_UINT8,
| 1347 | + dQfp8, |
| 1348 | + (uint32_t) head_dim, |
| 1349 | + (uint32_t)(batch_size * next_n * num_heads), |
| 1350 | + (uint32_t) head_dim, |
| 1351 | + (uint32_t)(next_n * num_heads), |
| 1352 | + (uint32_t) head_dim, |
| 1353 | + sizeof(unsigned char)); |
| 1354 | + // KV: [head_dim, block_kv, num_kv_blocks] |
| 1355 | + tma_ok = tma_ok && dg_fp8_encode_tma_3d(tma_kv, CU_TENSOR_MAP_DATA_TYPE_UINT8,
| 1356 | + dKfp8, |
| 1357 | + (uint32_t) head_dim, |
| 1358 | + (uint32_t) block_kv, |
| 1359 | + (uint32_t) num_kv_blocks, |
| 1360 | + (uint32_t) head_dim, |
| 1361 | + (uint32_t) block_kv, |
| 1362 | + 1u, |
| 1363 | + (uint32_t) head_dim, |
| 1364 | + (uint32_t)(head_dim * block_kv), |
| 1365 | + sizeof(unsigned char)); |
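| 1366 | + // (dKfp8 is a flat row-major [kv_end, head_dim] buffer; since kv_end = num_kv_blocks * block_kv
| 1367 | + // and the block table is the identity mapping, it can be addressed directly through this
| 1368 | + // [head_dim, block_kv, num_kv_blocks] tensor map.)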
| 1366 | + // KV scales: [block_kv, num_kv_blocks] |
| 1367 | + tma_ok = tma_ok && dg_fp8_encode_tma_2d(tma_kv_scales, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
| 1368 | + dIdxKScale, |
| 1369 | + (uint32_t) block_kv, |
| 1370 | + (uint32_t) num_kv_blocks, |
| 1371 | + (uint32_t) block_kv, |
| 1372 | + 1u, |
| 1373 | + (uint32_t) block_kv, |
| 1374 | + sizeof(float)); |
| 1375 | + // Weights: [next_n*num_heads, batch_size] |
| 1376 | + tma_ok = tma_ok && dg_fp8_encode_tma_2d(tma_w, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
| 1377 | + dWrm, |
| 1378 | + (uint32_t)(next_n * num_heads), |
| 1379 | + (uint32_t) batch_size, |
| 1380 | + (uint32_t)(next_n * num_heads), |
| 1381 | + 1u, |
| 1382 | + (uint32_t)(next_n * num_heads), |
| 1383 | + sizeof(float)); |
| 1384 | + GGML_ASSERT(tma_ok && "cuTensorMapEncodeTiled failed for the DeepGEMM FP8 path");
| 1385 | + // Launch DeepGEMM kernel |
| 1386 | + dim3 grid(num_sms, 1, 1); |
| 1387 | + dim3 block(num_specialized_threads + num_math_threads + 128, 1, 1); |
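| 1388 | + // Compile-time shape parameters; these must match the runtime shape guard above.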
| 1388 | + const uint32_t kNextN = 2; |
| 1389 | + const uint32_t kNumHeads = 64; |
| 1390 | + const uint32_t kHeadDim = 128; |
| 1391 | + const uint32_t kNumQStages = 3; |
| 1392 | + const uint32_t kNumKVStages = 3; |
| 1393 | + const uint32_t SPLIT_KV = block_kv * num_math_warp_groups; |
| 1394 | + const uint32_t kNumSpecializedThreads = 128; |
| 1395 | + const uint32_t kNumMathThreads = num_math_warp_groups * 128; |
| 1396 | + size_t shmem_bytes = 0; // computed in-kernel |
| 1397 | + |
| 1398 | + LAUNCH_PROFILE_KERNEL("PROFILE_DG_FP8", DG_FP8, stream, ([&](){ |
| 1399 | + deep_gemm::sm100_fp8_paged_mqa_logits< |
| 1400 | + kNextN, kNumHeads, |
| 1401 | + kHeadDim, (uint32_t) block_kv, |
| 1402 | + false, |
| 1403 | + kNumQStages, kNumKVStages, |
| 1404 | + SPLIT_KV, |
| 1405 | + kNumSpecializedThreads, kNumMathThreads |
| 1406 | + ><<<grid, block, shmem_bytes, stream>>>( |
| 1407 | + (uint32_t) batch_size, |
| 1408 | + (uint64_t) aligned_max_context_len, |
| 1409 | + (uint64_t) num_kv_blocks, |
| 1410 | + d_ctx_lens, |
| 1411 | + d_dg_logits, |
| 1412 | + d_block_tbl, |
| 1413 | + d_sched, |
| 1414 | + tma_q, |
| 1415 | + tma_kv, |
| 1416 | + tma_kv_scales, |
| 1417 | + tma_w); |
| 1418 | + })(), D, H, Tc, kv_end); |
| 1419 | + |
| 1420 | + // Map logits [batch*next_n, aligned_max_context_len] back to Out [kv, Tc]: with batch = 1 and
| 1421 | + // kv_end a multiple of SPLIT_KV (shape guard above), rows are exactly kv_end wide, so a plain transpose suffices.
| 1422 | + dim3 tb(32, 4); |
| 1423 | + dim3 gd((kv_end + tb.x - 1)/tb.x, (Tc + tb.y - 1)/tb.y); |
| 1424 | + k_transpose_TcKv_to_KvTc<<<gd, tb, 0, stream>>>(d_dg_logits, Tc, kv_end, dOut); |
| 1425 | + CUDA_CHECK(cudaGetLastError()); |
| 1426 | + CUDA_CHECK(cudaStreamSynchronize(stream)); // finish outstanding work before the pool buffers are released
| 1427 | + if (dStarts_tmp) cudaFree(dStarts_tmp); |
| 1428 | + if (dEnds_tmp) cudaFree(dEnds_tmp); |
| 1429 | + return; |
| 1430 | + } |
| 1431 | +#endif // CUDART_VERSION >= 12000 |
| 1432 | + } else if (const char *s = getenv("LLAMA_INDEXER_TL_PORT"); s && atoi(s) != 0) { |
| 1433 | + |
1153 | 1434 | ggml_cuda_pool & __pool = ctx.pool(ggml_cuda_get_device()); |
1154 | 1435 | bool use_tma_fp8 = false; |
1155 | 1436 | if (const char *e = getenv("LLAMA_TL_FP8"); e && atoi(e) != 0) use_tma_fp8 = true; |