
Commit c73d643

Merge pull request #106 from SmallDoges/Support-backward
2 parents bbeac61 + 2959585

9 files changed: +1070 −217 lines

csrc/flash_api.cpp

Lines changed: 87 additions & 10 deletions
@@ -78,8 +78,6 @@ void set_params_fprop(
     params.mask_head_stride = mask.stride(-3);
     params.bias_head_stride = bias.stride(-3);
     params.o_head_stride = out.stride(-2);
-    params.mask_col_stride = mask.stride(-1);
-    params.bias_col_stride = bias.stride(-1);
 
     if (cu_seqlens_q_d == nullptr) {
         params.q_batch_stride = q.stride(0);
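
The two deleted assignments go together with the removal of `mask_col_stride` / `bias_col_stride` from the parameter structs in `csrc/src/flash.h` below: the kernels now treat the last (key) dimension of `mask` and `bias` as contiguous. A minimal sketch of a host-side guard that would make that assumption explicit (the helper name is hypothetical, not part of this commit):

```cpp
#include <torch/torch.h>

// Hypothetical guard, not code from this commit: with the per-column strides gone,
// mask and bias must be contiguous along their last (key) dimension, i.e. stride(-1) == 1.
void check_mask_bias_last_dim_contiguous(const at::Tensor &mask, const at::Tensor &bias) {
    TORCH_CHECK(mask.stride(-1) == 1, "attention mask must have a contiguous last dimension");
    TORCH_CHECK(bias.stride(-1) == 1, "attention bias must have a contiguous last dimension");
}
```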
@@ -142,6 +140,93 @@ void set_params_fprop(
     params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
+void set_params_dgrad(
+    Flash_bwd_params &params,
+    // sizes
+    const size_t b,
+    const size_t seqlen_q,
+    const size_t seqlen_k,
+    const size_t seqlen_q_rounded,
+    const size_t seqlen_k_rounded,
+    const size_t h,
+    const size_t h_k,
+    const size_t d,
+    const size_t d_rounded,
+    // device pointers
+    const at::Tensor q,
+    const at::Tensor k,
+    const at::Tensor v,
+    const at::Tensor mask,
+    const at::Tensor bias,
+    const at::Tensor out,
+    const at::Tensor dout,
+    at::Tensor dq,
+    at::Tensor dk,
+    at::Tensor dv,
+    at::Tensor dbias,
+    void *cu_seqlens_q_d,
+    void *cu_seqlens_k_d,
+    void *dq_accum_d,
+    void *dk_accum_d,
+    void *dv_accum_d,
+    void *softmax_lse_d,
+    void *dsoftmax_sum_d,
+    float p_dropout,
+    float softmax_scale,
+    bool is_causal,
+    const float softcap,
+    bool deterministic,
+    const bool unpadded_lse
+) {
+    set_params_fprop(
+        params,
+        b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
+        q, k, v, mask, bias, out,
+        cu_seqlens_q_d,
+        cu_seqlens_k_d,
+        nullptr,
+        nullptr,
+        softmax_lse_d,
+        softmax_scale,
+        is_causal,
+        softcap,
+        false, // seqlenq_ngroups_swapped
+        unpadded_lse
+    );
+
+    // Set the pointers and strides.
+    params.do_ptr = dout.data_ptr();
+    params.do_row_stride = dout.stride(-3);
+    params.do_head_stride = dout.stride(-2);
+    params.dq_ptr = dq.data_ptr();
+    params.dk_ptr = dk.data_ptr();
+    params.dv_ptr = dv.data_ptr();
+    params.dbias_ptr = dbias.data_ptr();
+    params.dq_row_stride = dq.stride(-3);
+    params.dk_row_stride = dk.stride(-3);
+    params.dv_row_stride = dv.stride(-3);
+    params.dq_head_stride = dq.stride(-2);
+    params.dk_head_stride = dk.stride(-2);
+    params.dv_head_stride = dv.stride(-2);
+
+    if (cu_seqlens_q_d == nullptr) {
+        params.do_batch_stride = dout.stride(0);
+        params.dq_batch_stride = dq.stride(0);
+        params.dk_batch_stride = dk.stride(0);
+        params.dv_batch_stride = dv.stride(0);
+        params.dbias_batch_stride = dbias.stride(0);
+    }
+
+    params.dq_accum_ptr = dq_accum_d;
+    params.dk_accum_ptr = dk_accum_d;
+    params.dv_accum_ptr = dv_accum_d;
+
+    // Softmax sum
+    params.dsoftmax_sum = dsoftmax_sum_d;
+
+    params.deterministic = deterministic;
+}
+
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
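
The stride bookkeeping in `set_params_dgrad` follows the same convention as the forward path: for each gradient tensor, `stride(0)` is the batch stride, `stride(-3)` the row (sequence) stride, and `stride(-2)` the head stride. A small standalone sketch of what those values look like for a contiguous `[batch, seqlen, heads, head_dim]` tensor (shapes chosen purely for illustration):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    // A contiguous gradient tensor laid out as [batch, seqlen_q, num_heads, head_dim].
    auto dq = torch::empty({2, 128, 8, 64}, torch::kHalf);
    std::cout << "batch stride: " << dq.stride(0)  << "\n"   // 128 * 8 * 64 = 65536
              << "row stride:   " << dq.stride(-3) << "\n"   // 8 * 64 = 512
              << "head stride:  " << dq.stride(-2) << "\n";  // 64
    return 0;
}
```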
@@ -233,14 +318,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     }
     TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
 
-    // Temporarily disable Split-KV, because some bugs are still being fixed.
-    // See: https://github.com/SmallDoges/flash-dmattn/issues/47
-    // Regardless of how it is set externally, always set num_splits back to 1.
-    // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
-
     return std::make_tuple(softmax_lse_accum, out_accum);
 }
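
Removing the forced `params.num_splits = 1` (and the accompanying `reset()` calls) re-enables Split-KV, so `set_params_splitkv` can again hand back non-empty accumulators when it chooses more than one split. A sketch of the allocation pattern this implies, with shapes assumed from the usual FlashAttention Split-KV layout rather than taken from this diff:

```cpp
#include <torch/torch.h>
#include <tuple>

// Sketch only: per-split accumulators for the combine step, assuming the
// standard [num_splits, batch, heads, seqlen_q(, head_dim)] fp32 layout.
std::tuple<at::Tensor, at::Tensor> alloc_splitkv_accum(
        int num_splits, int batch_size, int num_heads,
        int max_seqlen_q, int head_size_rounded, const at::TensorOptions &opts) {
    at::Tensor softmax_lse_accum, out_accum;
    if (num_splits > 1) {
        softmax_lse_accum = torch::empty({num_splits, batch_size, num_heads, max_seqlen_q},
                                         opts.dtype(at::kFloat));
        out_accum = torch::empty({num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded},
                                 opts.dtype(at::kFloat));
    }
    return std::make_tuple(softmax_lse_accum, out_accum);
}
```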

csrc/src/block_info.h

Lines changed: 4 additions & 4 deletions
@@ -36,16 +36,16 @@ struct BlockInfo {
     }
 
     template <typename index_t>
-    __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const {
+    __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
-        sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
+        sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k);
         return offset;
     }
 
     template <typename index_t>
-    __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const {
+    __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
-        sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
+        sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k);
         return offset;
     }
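
With the `col_stride` argument gone, advancing along the key dimension becomes an implicit "+ 1" in elements, which only holds if that dimension is contiguous. A host-side sketch of the same arithmetic (mirroring `mask_offset` above; not code from this commit):

```cpp
#include <cstdint>
#include <cassert>

using index_t = int64_t;

// Mirrors the device-side mask_offset/bias_offset: sum_s_q / sum_s_k are -1 in the
// padded (non-varlen) path and cumulative sequence offsets in the varlen path.
index_t offset_sketch(index_t batch_stride, index_t row_stride,
                      int bidb, int sum_s_q, int sum_s_k, int leftpad_k) {
    index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    // The column step is implicitly one element now (no col_stride multiplier).
    offset += sum_s_k == -1 ? leftpad_k : uint32_t(sum_s_k + leftpad_k);
    return offset;
}

int main() {
    // Padded path: batch index 2, no left padding -> offset is just 2 * batch_stride.
    assert(offset_sketch(/*batch_stride=*/4096, /*row_stride=*/64,
                         /*bidb=*/2, /*sum_s_q=*/-1, /*sum_s_k=*/-1, /*leftpad_k=*/0) == 8192);
    return 0;
}
```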

csrc/src/flash.h

Lines changed: 0 additions & 3 deletions
@@ -50,7 +50,6 @@ struct Mask_params {
     index_t mask_batch_stride;  // Stride between batches of attention mask
     index_t mask_head_stride;   // Stride between heads of attention mask
     index_t mask_row_stride;    // Stride between rows of attention mask
-    index_t mask_col_stride;    // Stride between columns of attention mask
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -62,7 +61,6 @@ struct Bias_params {
     index_t bias_batch_stride;  // Stride between batches of attention bias
     index_t bias_head_stride;   // Stride between heads of attention bias
     index_t bias_row_stride;    // Stride between rows of attention bias
-    index_t bias_col_stride;    // Stride between columns of attention bias
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -179,7 +177,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
     index_t dbias_batch_stride;
     index_t dbias_head_stride;
     index_t dbias_row_stride;
-    index_t dbias_col_stride;
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
