Commit 6c53edf

Refactors shared memory copy initialization code
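Simplifies the creation of tiled shared memory copy objects by no longer chaining `.get_thread_slice(tidx)` onto the factory function call. The refactored approach creates the tiled copy object first, then obtains the thread slice from it in a separate statement, which keeps the code clearer. In this diff the changed line is the commented-out `make_tiled_copy_C` reference; the active code beneath it already follows the two-step pattern.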
1 parent 86c3e52 commit 6c53edf

1 file changed: +1 -1 lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 // Partition sP and sdS to match the accumulator partitioning
 // This has to be tiled_mma_sdp, not tiled_mma_dkv
-// auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
+// auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
 auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
 auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
 Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom, AtomNum), PIPE_M, PIPE_N)
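
The difference between the chained form and the two-step form can be sketched in plain C++ without CuTe. Everything below is a hypothetical stand-in: `MockTiledCopy`, `MockThrCopy`, `make_mock_tiled_copy`, and the even, contiguous work split are assumptions made for illustration and do not reflect how `make_tiled_copy_C` or `get_thread_slice` partition data in the real kernel.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical per-thread view; NOT the CuTe ThrCopy type.
struct MockThrCopy {
    std::size_t tidx;     // index of this thread
    std::size_t threads;  // total number of cooperating threads

    // Assumed policy: each thread owns one contiguous, equally sized chunk.
    std::size_t first_element(std::size_t n) const {
        return (n / threads) * tidx;
    }
};

// Hypothetical tiled copy object; NOT the CuTe TiledCopy type.
struct MockTiledCopy {
    std::size_t threads;

    MockThrCopy get_thread_slice(std::size_t tidx) const {
        return MockThrCopy{tidx, threads};
    }
};

// Hypothetical factory standing in for a make_tiled_copy-style function.
inline MockTiledCopy make_mock_tiled_copy(std::size_t threads) {
    return MockTiledCopy{threads};
}

// Chained form: the tiled object is a temporary and is gone afterwards.
inline MockThrCopy chained_form(std::size_t threads, std::size_t tidx) {
    return make_mock_tiled_copy(threads).get_thread_slice(tidx);
}

// Two-step form: the tiled object gets a name and stays available.
struct TwoStep {
    MockTiledCopy tiled;
    MockThrCopy thr;
};

inline TwoStep two_step_form(std::size_t threads, std::size_t tidx) {
    MockTiledCopy tiled = make_mock_tiled_copy(threads);
    MockThrCopy thr = tiled.get_thread_slice(tidx);
    return TwoStep{tiled, thr};
}

// Both forms yield the same per-thread view under the mock's assumptions.
inline void check_forms_agree(std::size_t threads, std::size_t tidx, std::size_t n) {
    assert(chained_form(threads, tidx).first_element(n) ==
           two_step_form(threads, tidx).thr.first_element(n));
}
```

In this mock, both forms produce an identical thread slice; the only difference is that the two-step form leaves the tiled object named, so later code could query or reuse it.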