Commit d0169a5

Adds tensor reshaping for mask and bias in backward kernel
Extends the existing tensor reshaping logic to cover the mask and bias tensors as well as the scores tensor. All three now go through the same layout conversion from the MMA accumulator format to row-column format, so subsequent computations can treat them as identically structured tensors.
1 parent 5655a9d commit d0169a5

1 file changed, +3 -1 lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 3 additions & 1 deletion
@@ -651,8 +651,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
     }
 
-    // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
+    // Reshape acc_s, mask, bias from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
     Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));
+    Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
+    Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
     // if (cute::thread(32, 0)) { print(scores); }
 
     // Softcapping - calculating dTanh and scaling dS later with it
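
The reshape in this hunk only changes how each thread's registers are indexed; no data moves. The host-side sketch below models that index mapping in plain C++, without CuTe. It rests on two assumptions that the diff does not confirm: that `convert_layout_acc_rowcol` splits the 4 values of mode 0 as `v = v_col + 2 * v_row`, as the helper of the same name does in upstream FlashAttention, and that the register fragment is compact with strides `(1, 4, 4 * mma_m)`. The names `RowCol`, `acc_to_rowcol`, `acc_offset`, `rowcol_offset`, and `mma_m` are hypothetical and exist only for illustration.

```cpp
#include <cassert>

// Hypothetical illustration type; not part of the kernel or of CuTe.
struct RowCol {
    int row;
    int col;
};

// Map a coordinate (v, m, n) of the (MMA=4, mode1, mode2) view to a flat
// (row, col) coordinate of the row-column view.
// Assumption: the 4 values split as v = v_col + 2 * v_row.
inline RowCol acc_to_rowcol(int v, int m, int n) {
    const int v_col = v % 2;  // low bit of v goes to the column mode
    const int v_row = v / 2;  // high bit of v goes to the row mode
    return RowCol{v_row + 2 * m, v_col + 2 * n};
}

// Register offset of (v, m, n) in the original view.
// Assumption: compact fragment with strides (1, 4, 4 * mma_m).
inline int acc_offset(int v, int m, int n, int mma_m) {
    return v + 4 * m + 4 * mma_m * n;
}

// Register offset of (row, col) in the row-column view. The strides are the
// same ones as above, regrouped, so the view aliases the same registers.
inline int rowcol_offset(int row, int col, int mma_m) {
    const int v_row = row % 2;
    const int m = row / 2;
    const int v_col = col % 2;
    const int n = col / 2;
    return v_col + 2 * v_row + 4 * m + 4 * mma_m * n;
}
```

Under these assumptions, the coordinate `(v=3, m=1, n=2)` maps to row 3, column 5, and both views resolve to the same register offset. If the three tensors share one row-column layout, as the commit message states, a single `(row, col)` pair addresses the matching element of `scores`, `mask`, and `bias`.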
