Commit be9c3ac

Adds mask and bias offset calculations to backward kernel
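Introduces row offset computations for the mask, bias, and bias gradient tensors in the backward-pass function compute_dq_dk_dv_1colblock. Each offset is the sum of four stride-based terms: the batch offset, the head offset (indexed by bidh / params.h_h_k_ratio), the row offset of the last query block (m_block_max - 1), and the column offset of the current key block (n_block). This lets the kernel address the attention mask and bias tiles correctly during gradient computation.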
1 parent 6d6ab5d commit be9c3ac

1 file changed, +6 -0 lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 6 additions & 0 deletions
@@ -107,6 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+    const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
+    const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
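
The three added offsets share one pattern: a batch term, a head term indexed by bidh / h_h_k_ratio, a row term for the last query block, and a column term for the current key block. The host-side C++ sketch below reproduces that arithmetic so it can be checked by hand. It is an illustration only, not the kernel's code: the Strides struct and the tile_offset function are hypothetical names, and the batch term is assumed to be bidb * batch_stride (fixed-length batches), since the bodies of binfo.mask_offset and binfo.bias_offset are not shown in this diff. The column term assumes a unit stride along the key dimension, as the added lines do.

```cpp
#include <cassert>
#include <cstdint>

using index_t = int64_t;

// Hypothetical stand-in for the per-tensor stride fields read from `params`
// (e.g. mask_batch_stride, mask_head_stride, mask_row_stride).
struct Strides {
    index_t batch_stride;
    index_t head_stride;
    index_t row_stride;
};

// Sketch of the offset pattern used for mask, bias, and dbias in the diff.
// Assumption: the batch term is bidb * batch_stride; the real
// binfo.mask_offset / binfo.bias_offset may compute it differently.
inline index_t tile_offset(const Strides &s, int bidb, int bidh, int h_h_k_ratio,
                           int m_block_max, int n_block, int kBlockM, int kBlockN) {
    assert(h_h_k_ratio > 0 && m_block_max >= 1);
    const index_t batch_part = index_t(bidb) * s.batch_stride;
    // Query heads that share a key/value head map to the same mask/bias head.
    const index_t head_part = index_t(bidh / h_h_k_ratio) * s.head_stride;
    // Start at the last query block, as the do/o offsets in the diff do.
    const index_t row_part = index_t(m_block_max - 1) * kBlockM * s.row_stride;
    // Column offset of the current key block, assuming unit column stride.
    const index_t col_part = index_t(n_block) * kBlockN;
    return batch_part + head_part + row_part + col_part;
}
```

For a contiguous tensor laid out as [batch, heads_k, seqlen_q, seqlen_k] with heads_k = 2, seqlen_q = 128, and seqlen_k = 256, the strides are 65536, 32768, and 256. With bidb = 1, bidh = 5, h_h_k_ratio = 4, m_block_max = 2, n_block = 3, kBlockM = 64, and kBlockN = 32, the four terms are 65536 + 32768 + 16384 + 96 = 114784.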
