Skip to content

Commit f7459fe

Browse files
committed
Fixes template parameter order in flash backward kernel
Corrects the template parameter list by removing an extra `false` parameter that was causing compilation issues or incorrect behavior in the kernel traits instantiation for devices with limited shared memory.
1 parent d9a4d5a commit f7459fe

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

csrc/src/flash_bwd_launch_template.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
274274
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
275275
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
276276
} else { // sm86 and sm89, max smem is 99 KB. V in regs and no double buffering.
277-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
277+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
278278
}
279279
}
280280

0 commit comments

Comments
 (0)