Skip to content

Commit 5ee5acf

Browse files
committed
Adds mask and bias tensor support to backward kernel
Introduces gMask, gBias, and gdBias tensor declarations to enable attention masking and bias functionality in the backward pass. Extends the kernel to handle masked attention computations and bias gradient calculations for more flexible attention mechanisms.
1 parent be9c3ac commit 5ee5acf

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
143143
Shape<Int<kBlockN>, Int<kHeadDim>>{},
144144
make_stride(params.v_row_stride, _1{})
145145
);
146+
Tensor gMask = make_tensor(
147+
make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + row_offset_mask),
148+
Shape<Int<kBlockM>, Int<kBlockN>>{},
149+
make_stride(params.mask_row_stride, _1{})
150+
);
151+
Tensor gBias = make_tensor(
152+
make_gmem_ptr(reinterpret_cast<Element *>(params.bias_ptr) + row_offset_bias),
153+
Shape<Int<kBlockM>, Int<kBlockN>>{},
154+
make_stride(params.bias_row_stride, _1{})
155+
);
156+
Tensor gdBias = make_tensor(
157+
make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_dbias),
158+
Shape<Int<kBlockM>, Int<kBlockN>>{},
159+
make_stride(params.dbias_row_stride, _1{})
160+
);
146161
Tensor gdO = make_tensor(
147162
make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
148163
Shape<Int<kBlockM>, Int<kHeadDim>>{},

0 commit comments

Comments (0)