Skip to content

Commit 7a44864

Browse files
committed
Optimizes kernel configurations for memory usage
Reduces kernel block dimensions so that the backward kernels fit within the shared-memory (SMEM) limits of different GPU architectures, and updates the shared-memory usage comments to match. Adjusts the kernel traits parameters to use smaller block sizes while maintaining performance, enabling better utilization of the available shared memory on devices with varying SMEM capacities. Also removes commented-out experimental configurations that are no longer needed.
1 parent 43d73e8 commit 7a44864

File tree

1 file changed

+22
-40
lines changed

1 file changed

+22
-40
lines changed

csrc/src/flash_bwd_launch_template.h

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
165165
}
166166
// printf("max_smem_per_block = %d\n", max_smem_per_block);
167167
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
168-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
169-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
170-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
171-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
172168
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
173169
if (max_smem_per_block >= 144 * 1024) {
174-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
170+
// 122KB
171+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
175172
// This has a lot of register spilling
176-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
173+
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
177174
} else {
178-
// if (params.h == params.h_k) {
179-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
180-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
181-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
182-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
183-
// } else {
184-
// }
175+
// 74KB
176+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
185177
}
186-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
187-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
188-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
189-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
190178
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
191-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
192-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
193-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
194-
195-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
196179
}
197180

198181
template<typename T, bool Is_causal>
@@ -208,10 +191,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
208191
}
209192
// printf("max_smem_per_block = %d\n", max_smem_per_block);
210193
if (max_smem_per_block >= 116 * 1024) {
211-
// 92KB
212-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
194+
// 94KB
195+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
213196
} else {
214-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
197+
// 94KB
198+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
215199
}
216200
}
217201

@@ -227,24 +211,17 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
227211
C10_CUDA_CHECK(status_);
228212
}
229213
// printf("max_smem_per_block = %d\n", max_smem_per_block);
230-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
214+
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
231215
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
232216
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
233-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
217+
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
234218
if (max_smem_per_block >= 144 * 1024) {
219+
// 114KB
235220
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
236-
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>>(params, stream);
237-
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>>(params, stream);
238-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
239-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
240-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>>(params, stream);
241221
} else {
242-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
243-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
222+
// 74KB
223+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
244224
}
245-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
246-
247-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
248225
}
249226

250227
template<typename T, bool Is_causal>
@@ -259,9 +236,11 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
259236
C10_CUDA_CHECK(status_);
260237
}
261238
if (max_smem_per_block >= 136 * 1024) {
239+
// 156KB
262240
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
263241
} else {
264-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
242+
// 102KB
243+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
265244
}
266245
}
267246

@@ -277,11 +256,14 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
277256
C10_CUDA_CHECK(status_);
278257
}
279258
if (max_smem_per_block >= 176 * 1024) { // H100
259+
// 196KB
280260
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
281261
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
282-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
262+
// 131KB
263+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
283264
} else { // sm86 and sm89, max smem is 99 KB. V in regs and no double buffering.
284-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
265+
// 90KB
266+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
285267
}
286268
}
287269

0 commit comments

Comments
 (0)