Commit edf840c

Adds mask and bias memory copy operations
Instantiates dedicated global-to-shared memory tiled-copy objects for the mask and bias tensors in the backward kernel function compute_dq_dk_dv_1colblock, and obtains a per-thread slice of each via get_thread_slice(tidx). This prepares the kernel to load attention masks and bias terms during gradient computation, separately from the existing Q/K/V copy.
1 parent 18db4a1 commit edf840c

1 file changed: +4 -0 lines changed
csrc/src/flash_bwd_kernel.h

Lines changed: 4 additions & 0 deletions
@@ -274,6 +274,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+    auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+    auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV>;
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
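
The added lines follow the same pattern as the existing QKV copy: construct a tiled-copy object, then ask it for the slice that belongs to thread tidx. As a rough host-side illustration of that idea only, the sketch below models a tiled copy in plain C++. ToyTiledCopy, its ThreadSlice, and copy_tile_with_all_threads are hypothetical names invented for this example; they are not part of CuTe or of this repository, and the layout (each thread moves `vec` contiguous elements, threads laid out row by row) is an assumption, not the layout that GmemTiledCopyMask or GmemTiledCopyBias actually use.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tiled copy: it only records how a
// rows x cols row-major tile is divided among num_threads threads,
// each moving `vec` contiguous elements per access.
// Assumes cols % vec == 0 and num_threads % (cols / vec) == 0.
struct ToyTiledCopy {
    int rows, cols, num_threads, vec;

    // Per-thread view, analogous in spirit to get_thread_slice(tidx).
    struct ThreadSlice {
        const ToyTiledCopy* parent;
        int tidx;

        // Copy only this thread's share of the tile from src to dst.
        void copy(const std::vector<float>& src, std::vector<float>& dst) const {
            const int threads_per_row = parent->cols / parent->vec;
            const int rows_per_pass = parent->num_threads / threads_per_row;
            const int col0 = (tidx % threads_per_row) * parent->vec;
            for (int r = tidx / threads_per_row; r < parent->rows; r += rows_per_pass) {
                for (int v = 0; v < parent->vec; ++v) {
                    const std::size_t i = static_cast<std::size_t>(r) * parent->cols + col0 + v;
                    dst[i] = src[i];
                }
            }
        }
    };

    ThreadSlice get_thread_slice(int tidx) const {
        assert(tidx >= 0 && tidx < num_threads);
        return ThreadSlice{this, tidx};
    }
};

// Run every "thread" sequentially and return the filled destination tile.
// Untouched elements would keep the sentinel value -1.
inline std::vector<float> copy_tile_with_all_threads(const ToyTiledCopy& tc,
                                                     const std::vector<float>& src) {
    std::vector<float> dst(src.size(), -1.0f);
    for (int tidx = 0; tidx < tc.num_threads; ++tidx) {
        tc.get_thread_slice(tidx).copy(src, dst);
    }
    return dst;
}
```

In this toy model, giving the mask and bias their own copy objects would correspond to creating separate ToyTiledCopy instances, each free to use a tile shape or vector width that differs from the one used for Q, K, and V.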