@@ -78,8 +78,6 @@ void set_params_fprop(
7878 params.mask_head_stride = mask.stride (-3 );
7979 params.bias_head_stride = bias.stride (-3 );
8080 params.o_head_stride = out.stride (-2 );
81- params.mask_col_stride = mask.stride (-1 );
82- params.bias_col_stride = bias.stride (-1 );
8381
8482 if (cu_seqlens_q_d == nullptr ) {
8583 params.q_batch_stride = q.stride (0 );
@@ -142,6 +140,93 @@ void set_params_fprop(
142140 params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
143141}
144142
143+ void set_params_dgrad (
144+ Flash_bwd_params ¶ms,
145+ // sizes
146+ const size_t b,
147+ const size_t seqlen_q,
148+ const size_t seqlen_k,
149+ const size_t seqlen_q_rounded,
150+ const size_t seqlen_k_rounded,
151+ const size_t h,
152+ const size_t h_k,
153+ const size_t d,
154+ const size_t d_rounded,
155+ // device pointers
156+ const at::Tensor q,
157+ const at::Tensor k,
158+ const at::Tensor v,
159+ const at::Tensor mask,
160+ const at::Tensor bias,
161+ const at::Tensor out,
162+ const at::Tensor dout,
163+ at::Tensor dq,
164+ at::Tensor dk,
165+ at::Tensor dv,
166+ at::Tensor dbias,
167+ void *cu_seqlens_q_d,
168+ void *cu_seqlens_k_d,
169+ void *dq_accum_d,
170+ void *dk_accum_d,
171+ void *dv_accum_d,
172+ void *softmax_lse_d,
173+ void *dsoftmax_sum_d,
174+ float p_dropout,
175+ float softmax_scale,
176+ bool is_causal,
177+ const float softcap,
178+ bool deterministic,
179+ const bool unpadded_lse
180+ ) {
181+ set_params_fprop (
182+ params,
183+ b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
184+ q, k, v, mask, bias, out,
185+ cu_seqlens_q_d,
186+ cu_seqlens_k_d,
187+ nullptr ,
188+ nullptr ,
189+ softmax_lse_d,
190+ softmax_scale,
191+ is_causal,
192+ softcap,
193+ false , // seqlenq_ngroups_swapped
194+ unpadded_lse
195+ );
196+
197+ // Set the pointers and strides.
198+ params.do_ptr = dout.data_ptr ();
199+ params.do_row_stride = dout.stride (-3 );
200+ params.do_head_stride = dout.stride (-2 );
201+ params.dq_ptr = dq.data_ptr ();
202+ params.dk_ptr = dk.data_ptr ();
203+ params.dv_ptr = dv.data_ptr ();
204+ params.dbias_ptr = dbias.data_ptr ();
205+ params.dq_row_stride = dq.stride (-3 );
206+ params.dk_row_stride = dk.stride (-3 );
207+ params.dv_row_stride = dv.stride (-3 );
208+ params.dq_head_stride = dq.stride (-2 );
209+ params.dk_head_stride = dk.stride (-2 );
210+ params.dv_head_stride = dv.stride (-2 );
211+
212+ if (cu_seqlens_q_d == nullptr ) {
213+ params.do_batch_stride = dout.stride (0 );
214+ params.dq_batch_stride = dq.stride (0 );
215+ params.dk_batch_stride = dk.stride (0 );
216+ params.dv_batch_stride = dv.stride (0 );
217+ params.dbias_batch_stride = dbias.stride (0 );
218+ }
219+
220+ params.dq_accum_ptr = dq_accum_d;
221+ params.dk_accum_ptr = dk_accum_d;
222+ params.dv_accum_ptr = dv_accum_d;
223+
224+ // Softmax sum
225+ params.dsoftmax_sum = dsoftmax_sum_d;
226+
227+ params.deterministic = deterministic;
228+ }
229+
145230void run_mha_fwd (Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false ) {
146231 FP16_SWITCH (!params.is_bf16 , [&] {
147232 HEADDIM_SWITCH (params.d , [&] {
@@ -233,14 +318,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
233318 }
234319 TORCH_CHECK (params.num_splits <= 128 , " num_splits > 128 not supported" );
235320
236- // Temporarily disable Split-KV, because some bugs are still being fixed.
237- // See: https://github.com/SmallDoges/flash-dmattn/issues/47
238- // Regardless of how it is set externally, always set num_splits back to 1.
239- // This is to avoid the extra memory overhead of Split-KV.
240- params.num_splits = 1 ;
241- softmax_lse_accum.reset ();
242- out_accum.reset ();
243-
244321 return std::make_tuple (softmax_lse_accum, out_accum);
245322}
246323
0 commit comments