Skip to content

Commit 86c3e52

Browse files
committed
Updates memory copy operations for mask and bias tensors
Replaces the standard tiled copy operations for the mask and bias tensors with warp-contiguous variants to improve memory access patterns. The generic make_tiled_copy_C calls are kept as commented-out references, and the active code now uses make_tiled_copy_C_warpcontiguousN, which optimizes the memory layout for better performance in GPU kernels.
1 parent 9900589 commit 86c3e52

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,20 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
354354
// if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
355355
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
356356

357+
// auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
358+
// auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
359+
// Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
360+
// auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
361+
// auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
362+
// Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
363+
364+
auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
365+
auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
366+
Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
367+
auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
368+
auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
369+
Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
370+
357371
// Partition sP and sdS to match the accumulator partitioning
358372
// This has to be tiled_mma_sdp, not tiled_mma_dkv
359373
// auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);

0 commit comments

Comments
 (0)