/******************************************************************************
 * Copyright (c) 2025, Jingze Shi and Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

namespace FLASH_NAMESPACE {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
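// Note: __grid_constant__ passes the const params struct to the kernel without a
// per-thread local copy. Per the CUDA docs it needs CUDA >= 11.7 and compute
// capability >= 7.0, which the sm80 gate above satisfies.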

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
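// For illustration (hypothetical kernel name), on sm80+ the invocation
//   DEFINE_FLASH_BACKWARD_KERNEL(my_bwd_kernel, bool Is_causal)
// expands to
//   template<typename Kernel_traits, bool Is_causal>
//   __global__ void my_bwd_kernel(__grid_constant__ const Flash_bwd_params params)
// while on older architectures KERNEL_PARAM_MODIFIER expands to nothing.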

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}


template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::clear_dKVaccum<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    FLASH_NAMESPACE::convert_dQ<Kernel_traits>(params, nsplits);
}
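
// nsplits is the number of dQaccum partial buffers to reduce when converting to dQ:
// it is 1 on the non-deterministic path and gridDimx on the deterministic path
// (see the kernel_dq launch in run_flash_bwd_seqk_parallel below).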

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    int gridDimx = num_n_block;
    if (params.deterministic) {
        int num_sm = get_num_sm(get_current_device());
        gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);
    }
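    // Worked example (hypothetical sizes): with 108 SMs and b * h = 2 * 16 = 32,
    // gridDimx = ceil(108 / 32) = 4, so each (batch, head) pair gets a fixed set of
    // n-block workers whose dQaccum splits are later reduced in a reproducible order.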
    dim3 grid_n(gridDimx, params.b, params.h);

    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply a mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
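    // E.g. (hypothetical sizes) seqlen_q = 512 with kBlockM = 128 and seqlen_k = 1024
    // with kBlockN = 128, and no varlen cu_seqlens, gives is_even_MN = true.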
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce the number of templates.
                        // If head dim > 128, we set IsEvenMNConst to false to reduce the number of templates.
                        // If Is_local is true, we set Is_causal to false.
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
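                        // Dynamic shared memory beyond the default 48 KB per-block limit
                        // requires an explicit opt-in via cudaFuncSetAttribute (a CUDA-wide
                        // rule, not something specific to this kernel).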
                        if (smem_size_dq_dk_dv >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
#endif
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
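    // Traits parameters below are assumed to follow the Flash_bwd_kernel_traits order
    // <Headdim, kBlockM, kBlockN, kNWarps, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ,
    //  Is_V_in_regs, No_double_buffer, T> from the kernel-traits header.
    // The smem bound below works out to 2 * ((3*128 + 2*128) * 32 + 2 * 128 * 128)
    // = 106496 bytes = 104 KB for Headdim = 32.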
    if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
        // We can afford more registers to keep V in registers
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
    } else {  // 96 KB
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
    }
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    // Changing AtomLayoutMdQ from 2 to 4 takes the same time
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
    // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
    if (max_smem_per_block >= 144 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
        // This has a lot of register spilling
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
    } else {
        // if (params.h == params.h_k) {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
        // } else {
        // }
    }
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    if (max_smem_per_block >= 116 * 1024) {
        // 92 KB
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    } else {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    }
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
    // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
    // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
    if (max_smem_per_block >= 144 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
        // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>>(params, stream);
        // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>>(params, stream);
    } else {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
    }
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 136 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
    } else {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
    }
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 176 * 1024) {  // H100
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
    } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
    } else {  // sm86 and sm89, max smem is 99 KB. V in regs and no double buffering.
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
    }
}
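
// A minimal dispatch sketch (hypothetical caller; the real dispatch lives in the
// per-headdim translation units that instantiate these templates):
//   FP16_SWITCH(!params.is_bf16, [&] {
//       run_mha_bwd_hdim128<elem_type, /*Is_causal=*/true>(params, stream);
//   });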

}  // namespace FLASH_NAMESPACE