Skip to content

Commit 7b90c9c

Browse files
committed
Adds mask and bias tensor partitioning support
Introduces tensor partitioning for mask and bias operations in the backward kernel computation function. Sets up the necessary tensor views for mask and bias data structures to enable proper memory access patterns during gradient computation.
1 parent edf840c commit 7b90c9c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
300300
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
301301
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
302302
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
303+
Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N)
304+
Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
305+
Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N)
306+
Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
307+
308+
303309
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N)
304310
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
305311
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

0 commit comments

Comments
 (0)