Skip to content

Commit 666718f

Browse files
authored
flash-attn: only build headdim 128 (#62)
Reduce package size by only compiling flash-attn src with hdim128
1 parent 76484ff commit 666718f

File tree

3 files changed

+137
-3
lines changed

3 files changed

+137
-3
lines changed

cmake/flash-attention.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ set(FLASHATTN_USE_CUDA_STATIC
1919
CACHE BOOL "flash-attn use static CUDA")
2020

2121
set(FLASHATTN_USE_STATIC_LIB
22-
OFF
22+
ON
2323
CACHE BOOL "use flash-attn static lib")
2424

25+
set(TARGET_HEADDIM_LIST "128" CACHE STRING "List of target HEADDIM values (overrides ALLOWED_HEADDIMS_LIST)")
26+
2527
# only static link when needed, to reduce size.
2628
if(ENABLE_NV_STATIC_LIB)
2729
set(FLASHATTN_USE_CUDA_STATIC ON)
@@ -69,6 +71,7 @@ include(ExternalProject)
6971
-DFLASHATTN_USE_CUDA_STATIC=${FLASHATTN_USE_CUDA_STATIC}
7072
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL}
7173
-DCUTLASS_INSTALL_PATH=${CUTLASS_INSTALL}
74+
-DTARGET_HEADDIM_LIST=${TARGET_HEADDIM_LIST}
7275
)
7376

7477
ExternalProject_Get_Property(project_flashattn SOURCE_DIR)
@@ -93,6 +96,7 @@ set_target_properties(flash-attention::flash-attn PROPERTIES
9396
include_directories(${FLASHATTN_INCLUDE_DIR})
9497
set(FLASHATTN_LIBRARY flash-attention::flash-attn)
9598

99+
unset(TARGET_HEADDIM_LIST)
96100
unset(FLASHATTN_CUDA_VERSION)
97101
unset(FLASHATTN_GPU_ARCHS)
98102
unset(FLASHATTN_USE_EXTERNAL_CUTLASS)

csrc/utility/format_enforcer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
#include <lmfe/jsonschemaparser.hpp>
1414
#include <lmfe/tokenenforcer.hpp>
1515

16+
#ifdef ENABLE_CUDA
17+
#include <cuda_runtime.h>
18+
#endif
19+
1620
namespace allspark {
1721
namespace util {
1822

third_party/patch/flash-attn.patch

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
diff --git a/.gitmodules b/.gitmodules
2+
deleted file mode 100644
3+
index 8d501cb..0000000
4+
--- a/.gitmodules
5+
+++ /dev/null
6+
@@ -1,3 +0,0 @@
7+
-[submodule "csrc/cutlass"]
8+
- path = csrc/cutlass
9+
- url = https://github.com/NVIDIA/cutlass.git
110
diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
211
new file mode 100644
3-
index 0000000..cc91698
12+
index 0000000..5cb0691
413
--- /dev/null
514
+++ b/csrc/CMakeLists.txt
6-
@@ -0,0 +1,117 @@
15+
@@ -0,0 +1,215 @@
716
+cmake_minimum_required(VERSION 3.18)
817
+
918
+project(FLASHATTN LANGUAGES CXX CUDA)
@@ -21,10 +30,12 @@ index 0000000..cc91698
2130
+set(FLASHATTN_USE_CUDA_STATIC
2231
+ OFF
2332
+ CACHE BOOL "use static CUDA")
33+
+
2434
+# Generate SASS for each architecture
2535
+foreach(arch ${FLASHATTN_GPU_ARCHS})
2636
+ list(APPEND GENCODES "${arch}-real")
2737
+endforeach()
38+
+
2839
+# Generate PTX for the last architecture
2940
+list(GET FLASHATTN_GPU_ARCHS -1 LATEST_GPU_ARCH)
3041
+list(APPEND GENCODES "${LATEST_GPU_ARCH}-virtual")
@@ -71,12 +82,108 @@ index 0000000..cc91698
7182
+set(FLASHATTN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/flash_attn/src
7283
+ ${PROJECT_SOURCE_DIR})
7384
+file(GLOB_RECURSE FLASHATTN_SRCS ${FLASHATTN_ROOT}/*.cu)
85+
+
7486
+# no bwd
7587
+file(GLOB_RECURSE FLASHATTN_BWD_SRCS ${FLASHATTN_ROOT}/*_bwd_*.cu)
7688
+foreach(file ${FLASHATTN_BWD_SRCS})
7789
+ list(REMOVE_ITEM FLASHATTN_SRCS "${file}")
7890
+endforeach()
7991
+
92+
+#############
93+
+
94+
+# Define allowed HEADDIM values
95+
+set(ALLOWED_HEADDIMS_LIST "32;64;96;128;160;192;224;256" CACHE STRING "List of allowed HEADDIM values")
96+
+set(TARGET_HEADDIM_LIST "" CACHE STRING "List of target HEADDIM values (overrides ALLOWED_HEADDIMS_LIST)")
97+
+
98+
+# Use default values if the user hasn't specified TARGET_HEADDIM_LIST
99+
+if(TARGET_HEADDIM_LIST STREQUAL "")
100+
+ set(TARGET_HEADDIM_LIST "${ALLOWED_HEADDIMS_LIST}")
101+
+endif()
102+
+
103+
+message(STATUS "ALLOWED_HEADDIMS_LIST: ${ALLOWED_HEADDIMS_LIST}")
104+
+message(STATUS "TARGET_HEADDIM_LIST: ${TARGET_HEADDIM_LIST}")
105+
+
106+
+# Validate that each value in TARGET_HEADDIM_LIST is in ALLOWED_HEADDIMS_LIST
107+
+foreach(dim IN LISTS TARGET_HEADDIM_LIST)
108+
+ list(FIND ALLOWED_HEADDIMS_LIST ${dim} index)
109+
+ if(${index} EQUAL -1)
110+
+ message(FATAL_ERROR "Unsupported HEADDIM value: ${dim}")
111+
+ endif()
112+
+endforeach()
113+
+
114+
+# Sort TARGET_HEADDIM_LIST to ensure ascending order
115+
+list(SORT TARGET_HEADDIM_LIST)
116+
+
117+
+# Generate the content of the HEADDIM_SWITCH macro
118+
+set(HEADDIM_SWITCH_CONTENT "")
119+
+string(APPEND HEADDIM_SWITCH_MACRO "/* Auto-generated HEADDIM dispatcher */\n")
120+
+set(HEADDIM_SWITCH_CONTENT " if (false) {} /* to allow else if */ \\\n")
121+
+foreach(dim IN LISTS TARGET_HEADDIM_LIST)
122+
+ string(APPEND HEADDIM_SWITCH_CONTENT
123+
+ " else if (HEADDIM == ${dim}) { \\\n"
124+
+ " constexpr static int kHeadDim = ${dim}; \\\n"
125+
+ " return __VA_ARGS__(); \\\n"
126+
+ " } \\\n")
127+
+endforeach()
128+
+
129+
+# Add the final else statement for unsupported HEADDIM values
130+
+string(APPEND HEADDIM_SWITCH_CONTENT
131+
+ " else { \\\n"
132+
+ " throw std::runtime_error(\"Unsupported HEADDIM: \" + std::to_string(HEADDIM)); \\\n"
133+
+ " } \\\n")
134+
+
135+
+# Generate the HEADDIM_SWITCH macro definition
136+
+set(HEADDIM_SWITCH_MACRO "")
137+
+string(APPEND HEADDIM_SWITCH_MACRO "/* Auto-generated HEADDIM dispatcher */\n")
138+
+string(APPEND HEADDIM_SWITCH_MACRO "#pragma once\n\n")
139+
+string(APPEND HEADDIM_SWITCH_MACRO "#include <string>\n\n")
140+
+string(APPEND HEADDIM_SWITCH_MACRO "#define HEADDIM_SWITCH(HEADDIM, ...) \\\n")
141+
+string(APPEND HEADDIM_SWITCH_MACRO " [&] { \\\n")
142+
+string(APPEND HEADDIM_SWITCH_MACRO "${HEADDIM_SWITCH_CONTENT}")
143+
+string(APPEND HEADDIM_SWITCH_MACRO " }()\n")
144+
+
145+
+# Generate a header file containing the HEADDIM_SWITCH macro
146+
+set(HEADDIM_SWITCH_HEADER "${FLASHATTN_ROOT}/headdim_switch.h")
147+
+file(WRITE ${HEADDIM_SWITCH_HEADER} "${HEADDIM_SWITCH_MACRO}\n")
148+
+
149+
+# Create an empty list to store the files to be kept
150+
+set(FILES_TO_KEEP)
151+
+
152+
+file(GLOB FLASH_FWD_FILES "${FLASHATTN_ROOT}/flash_fwd*.cu")
153+
+
154+
+# Iterate over all found .cu files
155+
+foreach(file_path IN LISTS FLASH_FWD_FILES)
156+
+ # Get the file name
157+
+ get_filename_component(file_name ${file_path} NAME)
158+
+
159+
+ # Use a regex to extract the hdim value
160+
+ # Assuming filename format: flash_fwd_*hdim<value>*.cu
161+
+ string(REGEX MATCH "flash_fwd_.*hdim([0-9]+).*\\.cu" _ "${file_name}")
162+
+
163+
+ if(NOT "${CMAKE_MATCH_1}" STREQUAL "")
164+
+ set(file_hdim "${CMAKE_MATCH_1}")
165+
+
166+
+ # Check if file_hdim is in TARGET_HEADDIM_LIST
167+
+ list(FIND TARGET_HEADDIM_LIST ${file_hdim} index)
168+
+ if(NOT ${index} EQUAL -1)
169+
+ # Add to FILES_TO_KEEP
170+
+ list(APPEND FILES_TO_KEEP "${file_path}")
171+
+ message(STATUS "Including CUDA file: ${file_path} with hdim=${file_hdim}")
172+
+ else()
173+
+ # Exclude the file and log the information
174+
+ message(STATUS "Excluding CUDA file: ${file_path} with hdim=${file_hdim} not in HEADDIM_LIST")
175+
+ endif()
176+
+ endif()
177+
+endforeach()
178+
+
179+
+list(REMOVE_ITEM FLASH_FWD_FILES ${FILES_TO_KEEP}) # remain files to remove
180+
+foreach(file ${FLASH_FWD_FILES})
181+
+ list(REMOVE_ITEM FLASHATTN_SRCS "${file}")
182+
+endforeach()
183+
+message("Flash Attention build source list: ${FLASHATTN_SRCS}")
184+
+
185+
+###############
186+
+
80187
+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_OPERATORS__")
81188
+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_CONVERSIONS__")
82189
+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF2_OPERATORS__")
@@ -368,3 +475,22 @@ index eb8bcea..30c9a2e 100644
368475
}
369476
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
370477
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
478+
diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
479+
index 20c2afd..01014a4 100644
480+
--- a/csrc/flash_attn/src/static_switch.h
481+
+++ b/csrc/flash_attn/src/static_switch.h
482+
@@ -87,6 +87,9 @@
483+
} \
484+
}()
485+
486+
+#include "headdim_switch.h"
487+
+
488+
+#ifndef HEADDIM_SWITCH
489+
#define HEADDIM_SWITCH(HEADDIM, ...) \
490+
[&] { \
491+
if (HEADDIM <= 32) { \
492+
@@ -115,3 +118,4 @@
493+
return __VA_ARGS__(); \
494+
} \
495+
}()
496+
+#endif

0 commit comments

Comments
 (0)