1+ diff --git a/.gitmodules b/.gitmodules
2+ deleted file mode 100644
3+ index 8d501cb..0000000
4+ --- a/.gitmodules
5+ +++ /dev/null
6+ @@ -1,3 +0,0 @@
7+ - [submodule "csrc/cutlass"]
8+ - path = csrc/cutlass
9+ - url = https://github.com/NVIDIA/cutlass.git
110diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
211new file mode 100644
3- index 0000000..cc91698
12+ index 0000000..5cb0691
413--- /dev/null
514+++ b/csrc/CMakeLists.txt
6- @@ -0,0 +1,117 @@
15+ @@ -0,0 +1,215 @@
716+ cmake_minimum_required(VERSION 3.18)
817+
918+ project(FLASHATTN LANGUAGES CXX CUDA)
@@ -21,10 +30,12 @@ index 0000000..cc91698
2130+ set(FLASHATTN_USE_CUDA_STATIC
2231+ OFF
2332+ CACHE BOOL "use static CUDA")
33+ +
2434+ # Generate SASS for each architecture
2535+ foreach(arch ${FLASHATTN_GPU_ARCHS})
2636+ list(APPEND GENCODES "${arch}-real")
2737+ endforeach()
38+ +
2839+ # Generate PTX for the last architecture
2940+ list(GET FLASHATTN_GPU_ARCHS -1 LATEST_GPU_ARCH)
3041+ list(APPEND GENCODES "${LATEST_GPU_ARCH}-virtual")
@@ -71,12 +82,108 @@ index 0000000..cc91698
7182+ set(FLASHATTN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/flash_attn/src
7283+ ${PROJECT_SOURCE_DIR})
7384+ file(GLOB_RECURSE FLASHATTN_SRCS ${FLASHATTN_ROOT}/*.cu)
85+ +
7486+ # no bwd
7587+ file(GLOB_RECURSE FLASHATTN_BWD_SRCS ${FLASHATTN_ROOT}/*_bwd_*.cu)
7688+ foreach(file ${FLASHATTN_BWD_SRCS})
7789+ list(REMOVE_ITEM FLASHATTN_SRCS "${file}")
7890+ endforeach()
7991+
92+ + #############
93+ +
94+ + # Define allowed HEADDIM values
95+ + set(ALLOWED_HEADDIMS_LIST "32;64;96;128;160;192;224;256" CACHE STRING "List of allowed HEADDIM values")
96+ + set(TARGET_HEADDIM_LIST "" CACHE STRING "List of target HEADDIM values (overrides ALLOWED_HEADDIMS_LIST)")
97+ +
98+ + # Use default values if the user hasn't specified TARGET_HEADDIM_LIST
99+ + if(TARGET_HEADDIM_LIST STREQUAL "")
100+ + set(TARGET_HEADDIM_LIST "${ALLOWED_HEADDIMS_LIST}")
101+ + endif()
102+ +
103+ + message(STATUS "ALLOWED_HEADDIMS_LIST: ${ALLOWED_HEADDIMS_LIST}")
104+ + message(STATUS "TARGET_HEADDIM_LIST: ${TARGET_HEADDIM_LIST}")
105+ +
106+ + # Validate that each value in TARGET_HEADDIM_LIST is in ALLOWED_HEADDIMS_LIST
107+ + foreach(dim IN LISTS TARGET_HEADDIM_LIST)
108+ + list(FIND ALLOWED_HEADDIMS_LIST ${dim} index)
109+ + if(${index} EQUAL -1)
110+ + message(FATAL_ERROR "Unsupported HEADDIM value: ${dim}")
111+ + endif()
112+ + endforeach()
113+ +
114+ + # Sort TARGET_HEADDIM_LIST to ensure ascending order
115+ + list(SORT TARGET_HEADDIM_LIST)
116+ +
117+ + # Generate the content of the HEADDIM_SWITCH macro
118+ + set(HEADDIM_SWITCH_CONTENT "")
119+ + string(APPEND HEADDIM_SWITCH_MACRO "/* Auto-generated HEADDIM dispatcher */\n")
120+ + set(HEADDIM_SWITCH_CONTENT " if (false) {} /* to allow else if */ \\\n")
121+ + foreach(dim IN LISTS TARGET_HEADDIM_LIST)
122+ + string(APPEND HEADDIM_SWITCH_CONTENT
123+ + " else if (HEADDIM == ${dim}) { \\\n"
124+ + " constexpr static int kHeadDim = ${dim}; \\\n"
125+ + " return __VA_ARGS__(); \\\n"
126+ + " } \\\n")
127+ + endforeach()
128+ +
129+ + # Add the final else statement for unsupported HEADDIM values
130+ + string(APPEND HEADDIM_SWITCH_CONTENT
131+ + " else { \\\n"
132+ + " throw std::runtime_error(\"Unsupported HEADDIM: \" + std::to_string(HEADDIM)); \\\n"
133+ + " } \\\n")
134+ +
135+ + # Generate the HEADDIM_SWITCH macro definition
136+ + set(HEADDIM_SWITCH_MACRO "")
137+ + string(APPEND HEADDIM_SWITCH_MACRO "/* Auto-generated HEADDIM dispatcher */\n")
138+ + string(APPEND HEADDIM_SWITCH_MACRO "#pragma once\n\n")
139+ + string(APPEND HEADDIM_SWITCH_MACRO "#include <string>\n\n")
140+ + string(APPEND HEADDIM_SWITCH_MACRO "#define HEADDIM_SWITCH(HEADDIM, ...) \\\n")
141+ + string(APPEND HEADDIM_SWITCH_MACRO " [&] { \\\n")
142+ + string(APPEND HEADDIM_SWITCH_MACRO "${HEADDIM_SWITCH_CONTENT}")
143+ + string(APPEND HEADDIM_SWITCH_MACRO " }()\n")
144+ +
145+ + # Generate a header file containing the HEADDIM_SWITCH macro
146+ + set(HEADDIM_SWITCH_HEADER "${FLASHATTN_ROOT}/headdim_switch.h")
147+ + file(WRITE ${HEADDIM_SWITCH_HEADER} "${HEADDIM_SWITCH_MACRO}\n")
148+ +
149+ + # Create an empty list to store the files to be kept
150+ + set(FILES_TO_KEEP)
151+ +
152+ + file(GLOB FLASH_FWD_FILES "${FLASHATTN_ROOT}/flash_fwd*.cu")
153+ +
154+ + # Iterate over all found .cu files
155+ + foreach(file_path IN LISTS FLASH_FWD_FILES)
156+ + # Get the file name
157+ + get_filename_component(file_name ${file_path} NAME)
158+ +
159+ + # Use a regex to extract the hdim value
160+ + # Assuming filename format: flash_fwd_*hdim<value>*.cu
161+ + string(REGEX MATCH "flash_fwd_.*hdim([0-9]+).*\\.cu" _ "${file_name}")
162+ +
163+ + if(NOT "${CMAKE_MATCH_1}" STREQUAL "")
164+ + set(file_hdim "${CMAKE_MATCH_1}")
165+ +
166+ + # Check if file_hdim is in TARGET_HEADDIM_LIST
167+ + list(FIND TARGET_HEADDIM_LIST ${file_hdim} index)
168+ + if(NOT ${index} EQUAL -1)
169+ + # Add to FILES_TO_KEEP
170+ + list(APPEND FILES_TO_KEEP "${file_path}")
171+ + message(STATUS "Including CUDA file: ${file_path} with hdim=${file_hdim}")
172+ + else()
173+ + # Exclude the file and log the information
174+ + message(STATUS "Excluding CUDA file: ${file_path} with hdim=${file_hdim} not in HEADDIM_LIST")
175+ + endif()
176+ + endif()
177+ + endforeach()
178+ +
179+ + list(REMOVE_ITEM FLASH_FWD_FILES ${FILES_TO_KEEP}) # remain files to remove
180+ + foreach(file ${FLASH_FWD_FILES})
181+ + list(REMOVE_ITEM FLASHATTN_SRCS "${file}")
182+ + endforeach()
183+ + message("Flash Attention build source list: ${FLASHATTN_SRCS}")
184+ +
185+ + ###############
186+ +
80187+ list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_OPERATORS__")
81188+ list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_CONVERSIONS__")
82189+ list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF2_OPERATORS__")
@@ -368,3 +475,22 @@ index eb8bcea..30c9a2e 100644
368475 }
369476 // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
370477 DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
478+ diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
479+ index 20c2afd..01014a4 100644
480+ --- a/csrc/flash_attn/src/static_switch.h
481+ +++ b/csrc/flash_attn/src/static_switch.h
482+ @@ -87,6 +87,9 @@
483+ } \
484+ }()
485+
486+ + #include "headdim_switch.h"
487+ +
488+ + #ifndef HEADDIM_SWITCH
489+ #define HEADDIM_SWITCH(HEADDIM, ...) \
490+ [&] { \
491+ if (HEADDIM <= 32) { \
492+ @@ -115,3 +118,4 @@
493+ return __VA_ARGS__(); \
494+ } \
495+ }()
496+ + #endif
0 commit comments