Skip to content

Commit 18db4a1

Browse files
committed
Adds bias and mask tensor support to shared memory layout
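Introduces dedicated shared memory tensors for the mask (`sMask`) and bias (`sBias`) operations, laid out one after the other immediately following `sK`. A third tensor, `sdBias`, is created from the same base pointer and layout as `sBias`, so it reuses that storage instead of extending the allocation. Updates the memory pointer calculations for the existing value and gradient tensors after the bias tensor insertion: `sV` now begins at the end of `sBias` rather than the end of `sK`, and the `Is_V_in_regs` branch of the `sdS` base pointer is likewise moved from the end of `sK` to the end of `sBias`.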
1 parent 5ee5acf commit 18db4a1

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
225225
sK.data(),
226226
typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}
227227
);
228-
Tensor sV = make_tensor(
228+
Tensor sMask = make_tensor(
229229
sK.data() + size(sK),
230+
typename Kernel_traits::SmemLayoutMask{}
231+
);
232+
Tensor sBias = make_tensor(
233+
sMask.data() + size(sMask),
234+
typename Kernel_traits::SmemLayoutBias{}
235+
);
236+
Tensor sdBias = make_tensor(
237+
sBias.data(),
238+
typename Kernel_traits::SmemLayoutBias{}
239+
);
240+
Tensor sV = make_tensor(
241+
sBias.data() + size(sBias),
230242
typename Kernel_traits::SmemLayoutKV{}
231243
);
232244
Tensor sdS = make_tensor(
233-
!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
245+
!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sBias.data() + size(sBias),
234246
typename Kernel_traits::SmemLayoutPdS{}
235247
);
236248
Tensor sdSt = make_tensor(

0 commit comments

Comments
 (0)