Skip to content

Commit 5655a9d

Browse files
committed
Adds mask and bias copying from shared memory to registers
Implements tensor copy operations that move mask and bias data from shared memory into register storage before computation. Creates register tensors with the same shape as the score accumulator (`acc_s`) and uses retiled copy views to efficiently transfer the data, preparing for subsequent processing steps.
1 parent a092dfd commit 5655a9d

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -621,6 +621,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
621621
cute::cp_async_wait<0>();
622622
__syncthreads();
623623

624+
// Copy mask and bias from smem to registers
625+
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
626+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
627+
Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
628+
cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
629+
Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
630+
cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
631+
624632
Tensor dP_sum = make_fragment_like(lse);
625633
#pragma unroll
626634
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }

0 commit comments

Comments
 (0)