Commit 0f60141

Adds backward pass parameter setup and removes unused stride parameters
Introduces a new function, set_params_dgrad, that configures the parameters for the backward (gradient) pass, enabling gradient computation for the flash attention implementation. Removes the unused mask and bias column stride parameters from the forward pass setup to clean up the parameter structure. Also re-enables Split-KV by removing the temporary workaround that forced num_splits back to 1 while the underlying bugs were being fixed.
1 parent 8ddcc34 commit 0f60141
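
For context on the stride fields these setup functions populate: strides are read with negative indices, presumably so the same indices pick out the per-token and per-head strides for both batched and varlen-packed tensors. Below is a minimal libtorch sketch of how those indices map onto the batch/row/head strides, assuming the usual (batch, seqlen, num_heads, head_dim) layout of FlashAttention-style APIs; the layout itself is an assumption, not something this diff states.

#include <torch/torch.h>
#include <iostream>

int main() {
    // Hypothetical (batch, seqlen, num_heads, head_dim) tensor; the layout is assumed here.
    auto dq = torch::empty({2, 128, 8, 64}, torch::kHalf);
    std::cout << "batch stride (stride(0)):  " << dq.stride(0)  << "\n";  // 128*8*64 = 65536
    std::cout << "row stride   (stride(-3)): " << dq.stride(-3) << "\n";  // 8*64     = 512
    std::cout << "head stride  (stride(-2)): " << dq.stride(-2) << "\n";  // 64
}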

File tree

1 file changed: +87 -10 lines changed

csrc/flash_api.cpp

Lines changed: 87 additions & 10 deletions
@@ -78,8 +78,6 @@ void set_params_fprop(
     params.mask_head_stride = mask.stride(-3);
     params.bias_head_stride = bias.stride(-3);
     params.o_head_stride = out.stride(-2);
-    params.mask_col_stride = mask.stride(-1);
-    params.bias_col_stride = bias.stride(-1);
 
     if (cu_seqlens_q_d == nullptr) {
         params.q_batch_stride = q.stride(0);
@@ -142,6 +140,93 @@ void set_params_fprop(
     params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
+void set_params_dgrad(
+    Flash_bwd_params &params,
+    // sizes
+    const size_t b,
+    const size_t seqlen_q,
+    const size_t seqlen_k,
+    const size_t seqlen_q_rounded,
+    const size_t seqlen_k_rounded,
+    const size_t h,
+    const size_t h_k,
+    const size_t d,
+    const size_t d_rounded,
+    // device pointers
+    const at::Tensor q,
+    const at::Tensor k,
+    const at::Tensor v,
+    const at::Tensor mask,
+    const at::Tensor bias,
+    const at::Tensor out,
+    const at::Tensor dout,
+    at::Tensor dq,
+    at::Tensor dk,
+    at::Tensor dv,
+    at::Tensor dbias,
+    void *cu_seqlens_q_d,
+    void *cu_seqlens_k_d,
+    void *dq_accum_d,
+    void *dk_accum_d,
+    void *dv_accum_d,
+    void *softmax_lse_d,
+    void *dsoftmax_sum_d,
+    float p_dropout,
+    float softmax_scale,
+    bool is_causal,
+    const float softcap,
+    bool deterministic,
+    const bool unpadded_lse
+) {
+    set_params_fprop(
+        params,
+        b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
+        q, k, v, mask, bias, out,
+        cu_seqlens_q_d,
+        cu_seqlens_k_d,
+        nullptr,
+        nullptr,
+        softmax_lse_d,
+        softmax_scale,
+        is_causal,
+        softcap,
+        false, // seqlenq_ngroups_swapped
+        unpadded_lse
+    );
+
+    // Set the pointers and strides.
+    params.do_ptr = dout.data_ptr();
+    params.do_row_stride = dout.stride(-3);
+    params.do_head_stride = dout.stride(-2);
+    params.dq_ptr = dq.data_ptr();
+    params.dk_ptr = dk.data_ptr();
+    params.dv_ptr = dv.data_ptr();
+    params.dbias_ptr = dbias.data_ptr();
+    params.dq_row_stride = dq.stride(-3);
+    params.dk_row_stride = dk.stride(-3);
+    params.dv_row_stride = dv.stride(-3);
+    params.dq_head_stride = dq.stride(-2);
+    params.dk_head_stride = dk.stride(-2);
+    params.dv_head_stride = dv.stride(-2);
+
+    if (cu_seqlens_q_d == nullptr) {
+        params.do_batch_stride = dout.stride(0);
+        params.dq_batch_stride = dq.stride(0);
+        params.dk_batch_stride = dk.stride(0);
+        params.dv_batch_stride = dv.stride(0);
+        params.dbias_batch_stride = dbias.stride(0);
+    }
+
+    params.dq_accum_ptr = dq_accum_d;
+    params.dk_accum_ptr = dk_accum_d;
+    params.dv_accum_ptr = dv_accum_d;
+
+    // Softmax sum
+    params.dsoftmax_sum = dsoftmax_sum_d;
+
+    params.deterministic = deterministic;
+}
+
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
@@ -233,14 +318,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     }
     TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
 
-    // Temporarily disable Split-KV, because some bugs are still being fixed.
-    // See: https://github.com/SmallDoges/flash-dmattn/issues/47
-    // Regardless of how it is set externally, always set num_splits back to 1.
-    // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
-
     return std::make_tuple(softmax_lse_accum, out_accum);
 }
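
With the workaround removed, set_params_splitkv can hand back populated softmax_lse_accum and out_accum buffers again. The self-contained libtorch sketch below illustrates why a per-split log-sum-exp plus a per-split partial output are enough to reconstruct the exact attention result. It is an illustration of the Split-KV combine step in general, not the repository's kernel code, and every name in it is local to the sketch.

#include <torch/torch.h>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
    torch::manual_seed(0);
    const int seqlen_k = 256, d = 64, num_splits = 4;
    auto q = torch::randn({d});                  // one query row, for clarity
    auto k = torch::randn({seqlen_k, d});
    auto v = torch::randn({seqlen_k, d});
    const double scale = 1.0 / std::sqrt((double)d);

    // Reference: ordinary softmax attention for this query.
    auto s   = torch::matmul(k, q) * scale;                 // (seqlen_k)
    auto ref = torch::matmul(torch::softmax(s, 0), v);      // (d)

    // Split-KV: each split keeps a partial output and its log-sum-exp.
    std::vector<torch::Tensor> lse_parts, out_parts;
    auto ks = k.chunk(num_splits, 0);
    auto vs = v.chunk(num_splits, 0);
    for (int i = 0; i < num_splits; ++i) {
        auto si = torch::matmul(ks[i], q) * scale;          // scores for this split
        lse_parts.push_back(torch::logsumexp(si, 0));
        out_parts.push_back(torch::matmul(torch::softmax(si, 0), vs[i]));
    }
    auto lse_accum = torch::stack(lse_parts);               // (num_splits)
    auto out_accum = torch::stack(out_parts);               // (num_splits, d)

    // Combine: rescale each partial by exp(lse_i - lse) and sum.
    auto lse      = torch::logsumexp(lse_accum, 0);
    auto weights  = torch::exp(lse_accum - lse).unsqueeze(1);
    auto combined = (weights * out_accum).sum(0);

    std::cout << (combined - ref).abs().max().item<float>() << "\n";  // ~1e-6
}

Seen this way, the removed lines (forcing num_splits to 1 and resetting both accumulators) disabled the path outright: with a single split there are no partials to combine and no accumulator memory to pay for.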
