
Commit 3a0417b

yejunjin and jinyejun.jyj authored
Add flash attention on intel-avx512 platform (#18)
* fix: tweak glm-4-9b-chat model
* feat: add llama3 model
* feat: add flash attention support at intel-avx512 platform

Co-authored-by: jinyejun.jyj <jinyejun.jyj@alibaba-inc.com>
1 parent 5ebd2c0 commit 3a0417b

File tree

25 files changed: +1096 −59 lines


CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@ set(CONFIG_HOST_CPU_TYPE "X86" CACHE STRING "host cpu type, like X86, ARMV9, etc

 ## x86 related option.
 option(ENABLE_AVX2 "enable avx2" ON)
+option(ENABLE_AVX512 "enable avx512" ON)

 ## ARM related option.
 option(ENABLE_ARMCL "enable use of Arm Compute Library" OFF)
@@ -55,6 +56,10 @@ if(ENABLE_AVX2)
   list(APPEND ALLSPARK_DEFINITION "-DENABLE_AVX2")
 endif()

+if(ENABLE_AVX512)
+  list(APPEND ALLSPARK_DEFINITION "-DENABLE_AVX512")
+endif()
+
 if(ENABLE_ARM_V84_V9)
   list(APPEND ALLSPARK_DEFINITION "-DENABLE_ARM_V84_V9")
   if (ENABLE_BF16)
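Besides gating the AVX-512 build settings below, the new option appends -DENABLE_AVX512 to ALLSPARK_DEFINITION, presumably so that C++ sources can select AVX-512 code paths at compile time. A hedged sketch of that kind of gating follows; the function names are hypothetical and are not taken from the DashInfer sources.

// Hedged sketch: consuming the -DENABLE_AVX512 definition added above.
// attention_avx512/attention_generic are hypothetical names used only to
// illustrate compile-time gating; they are not DashInfer functions.
#include <cstdio>

static void attention_avx512() { std::puts("AVX-512 attention path"); }
static void attention_generic() { std::puts("generic attention path"); }

int main() {
#ifdef ENABLE_AVX512
  attention_avx512();
#else
  attention_generic();
#endif
  return 0;
}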

README.md

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@ Written in C++ runtime, DashInfer aims to deliver production-level implementatio
 - **Support for Mainstream Open-Source LLMs**: DashInfer supports mainstream open-source LLMs, including Qwen, LLaMA, ChatGLM, etc., and supports loading models in the Huggingface format.
 - **Post Training Quantization (PTQ)**: Using DashInfer's InstantQuant (IQ), weight-only quantization acceleration can be achieved without fine-tuning, improving deployment efficiency. Accuracy evaluation shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
 - **Optimized Computation Kernels**: With OneDNN and self-developed assembly kernels, DashInfer is able to maximize the performance of the hardware on both ARM and x86.
+- **Supports Flash Attention**: Significantly accelerates the attention computation for long sequences, drastically reducing first-token latency.
 - **NUMA-Aware Design**: DashInfer supports tensor parallel inference across multiple NUMA nodes, fully leveraging the computing power of server CPUs. With numactl and a multi-process architecture, the NUMA affinity of threads is accurately controlled to maximize the performance of multi-node CPUs and avoid the performance degradation caused by cross-NUMA access. For more information on NUMA, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html).
 - **Context Length**: The current version supports up to 32k context length, with plans to extend to longer context lengths in the future.
 - **Multi-Language API Interfaces**: Both C++ and Python interfaces are supported. It is possible to extend C++ interface to Java, Rust and other programming languages, via standard cross-language interfaces.
@@ -88,6 +89,7 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi
 | ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
 | ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
 | LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
+| LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |

 # Software Architecture

@@ -187,7 +189,7 @@ This subsection lists the third-party dependencies for the different stages of D

 # Future Plans

-- [ ] Accelerate attention with Flash-Attention
+- [x] Accelerate attention with Flash-Attention
 - [x] Expand context length to over 32k
 - [ ] Support 4-bit quantization
 - [ ] Support quantized models fine-tuned with GPTQ

README_CN.md

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@ DashInfer is written in a C++ runtime and provides C++ and Python interfaces. DashInfer
 - **Support for Mainstream Open-Source LLMs**: Supports mainstream open-source LLMs, including Qwen, LLaMA, ChatGLM, etc., and supports loading models in the Huggingface format.
 - **Post Training Quantization (PTQ)**: Using DashInfer's InstantQuant (IQ), weight-only quantization acceleration can be achieved without fine-tuning, improving deployment efficiency. Accuracy testing shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
 - **Optimized Computation Kernels**: Combining OneDNN with self-developed assembly kernels, DashInfer can maximize hardware performance on both ARM and x86.
+- **Supports Flash Attention**: Significantly accelerates attention computation for long sequences, drastically reducing first-token latency.
 - **NUMA-Aware**: Supports tensor parallel inference across multiple NUMA nodes, fully leveraging the compute power of server-grade CPUs. With numactl and a multi-process architecture, the NUMA affinity of compute threads is precisely controlled to make full use of multi-node CPU performance and avoid the degradation caused by cross-NUMA memory access. For NUMA performance guidance, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html)
 - **Context Length**: The current version supports a context length of 32k, with support for longer context lengths planned.
 - **Multi-Language API Interfaces**: Provides C++ and Python interfaces; the C++ interface can be used directly to bridge to Java, Rust, and other programming languages.
@@ -89,6 +90,7 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$
 | ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
 | ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
 | LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
+| LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |

 # Software Architecture

@@ -188,12 +190,13 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$

 # Future Plans

-- [ ] First-token acceleration: add CPU implementations of attention acceleration techniques such as Flash-Attention;
+- [x] First-token acceleration: add CPU implementations of attention acceleration techniques such as Flash-Attention;
 - [x] Context Length: extend beyond 32k;
 - [ ] Low-bit quantization: support 4-bit quantization;
 - [ ] QAT quantization support: support models fine-tuned with GPTQ quantization;
 - [ ] MoE: support MoE models and architectures.

+
 # License

 The DashInfer source code is licensed under Apache 2.0; the full license text can be found in the root directory of this repository.

build.sh

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ elif [ "${with_platform,,}" == "armclang" ]; then
       -DBUILD_PACKAGE=${build_package} \
       -DALLSPARK_CBLAS=BLIS \
       -DENABLE_AVX2=OFF \
+      -DENABLE_AVX512=OFF \
       -DENABLE_ARM_V84_V9=ON \
       -DENABLE_BF16=ON \
       -DENABLE_FP16=ON \

csrc/common/env_config.h

Lines changed: 11 additions & 0 deletions
@@ -59,4 +59,15 @@ class EnvVarConfig {
   }
 };

+class AttentionEnvConfig {
+ public:
+  static int GetFlashThresh() {
+    static int env_flash_thresh = -1;
+    if (env_flash_thresh == -1) {
+      env_flash_thresh = EnvVarConfig::GetInt("AS_FLASH_THRESH", 1024);
+    }
+    return env_flash_thresh;
+  }
+};
+
 }  // namespace allspark
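GetFlashThresh() reads the AS_FLASH_THRESH environment variable once, caches it, and defaults to 1024. How the threshold is consumed is not shown in this diff; the sketch below assumes the common pattern of enabling the flash-attention path only when the input sequence length reaches the threshold. GetFlashThreshSketch and UseFlashAttention are illustrative stand-ins, not DashInfer code.

// Hedged sketch: one plausible way the AS_FLASH_THRESH threshold could be
// used to choose between the flash-attention path and the regular path.
// GetFlashThreshSketch re-implements the getter above as a standalone
// example; UseFlashAttention and the ">=" rule are assumptions made for
// illustration, not the actual DashInfer dispatch logic.
#include <cstdlib>

static int GetFlashThreshSketch() {
  static int thresh = -1;
  if (thresh == -1) {
    const char* v = std::getenv("AS_FLASH_THRESH");  // same variable, default 1024
    thresh = v ? std::atoi(v) : 1024;
  }
  return thresh;
}

// Hypothetical dispatch: use flash attention only for long prefill sequences.
static bool UseFlashAttention(int input_seq_len) {
  return input_seq_len >= GetFlashThreshSketch();
}

int main() {
  // e.g. AS_FLASH_THRESH=512 ./a.out flips the result for a 600-token prompt
  return UseFlashAttention(600) ? 0 : 1;
}

Under this reading, raising AS_FLASH_THRESH keeps short prompts on the regular attention path, while a very large value would effectively disable the flash path.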

csrc/core/kernel/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -41,6 +41,12 @@ file(
   cpu/rotary.cpp
 )

+file(
+  GLOB_RECURSE
+  src_avx512
+  cpu/mha.cpp
+)
+
 file(
   GLOB_RECURSE
   src_arm
@@ -53,6 +59,14 @@ if(ENABLE_AVX2)
   set_source_files_properties(${src_avx2} PROPERTIES COMPILE_FLAGS "${AVX2_FLAGS}")
 endif(ENABLE_AVX2)

+if(ENABLE_AVX512)
+  set(AVX512_FLAGS "-mavx512f -mavx512bw -mavx512vl")
+  message("AVX512 flags: ${AVX512_FLAGS}, files: ${src_avx512}")
+  get_source_file_property(OTHER_FLAGS ${src_avx512} COMPILE_FLAGS)
+  message("APPEND flags: ${OTHER_FLAGS}, files: ${src_avx512}")
+  set_source_files_properties(${src_avx512} PROPERTIES COMPILE_FLAGS "${OTHER_FLAGS} ${AVX512_FLAGS}")
+endif(ENABLE_AVX512)
+
 if(NOT ENABLE_ARM_V84_V9)
   foreach(file ${src_arm})
     list(REMOVE_ITEM src_cpu_common "${file}")
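The flags -mavx512f -mavx512bw -mavx512vl let the compiler accept AVX-512 intrinsics and emit 512-bit code for cpu/mha.cpp only, leaving the rest of the build at its default ISA level. The snippet below is an illustrative example of the kind of code these flags enable (a query-key style dot product); it is not the actual DashInfer kernel.

// Hedged example: the sort of AVX-512 intrinsics the flags above allow in
// cpu/mha.cpp. Illustrative only, not the DashInfer kernel. Build with -mavx512f.
#include <immintrin.h>
#include <cstdio>

// Dot product of two float vectors of length n, 16 lanes at a time.
float DotProductAVX512(const float* a, const float* b, int n) {
  __m512 acc = _mm512_setzero_ps();
  int i = 0;
  for (; i + 16 <= n; i += 16) {
    __m512 va = _mm512_loadu_ps(a + i);
    __m512 vb = _mm512_loadu_ps(b + i);
    acc = _mm512_fmadd_ps(va, vb, acc);  // acc += va * vb
  }
  if (i < n) {  // masked tail load avoids reading past the end of the arrays
    __mmask16 m = static_cast<__mmask16>((1u << (n - i)) - 1);
    __m512 va = _mm512_maskz_loadu_ps(m, a + i);
    __m512 vb = _mm512_maskz_loadu_ps(m, b + i);
    acc = _mm512_fmadd_ps(va, vb, acc);
  }
  return _mm512_reduce_add_ps(acc);  // horizontal sum of the 16 lanes
}

int main() {
  float a[20], b[20];
  for (int i = 0; i < 20; ++i) { a[i] = 1.0f; b[i] = 2.0f; }
  std::printf("%f\n", DotProductAVX512(a, b, 20));  // prints 40.000000
  return 0;
}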

csrc/core/kernel/cpu/cpu_kernel.h

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,16 @@ void EmbeddingT5KernelLauncher(T* out_tensor, const int64_t* word_ids,
                                const T* embedding_table, int batch_size,
                                int seq_len, int hidden_size, int vocab_size,
                                bool use_decoder);
+
+template <typename T>
+void SelfScaledDpAttention(T* output, const T* query, const T* key,
+                           const T* value, int q_num_heads, int kv_num_heads,
+                           int size_per_head, int o_stride, int q_stride,
+                           int kv_stride, int batch_size,
+                           const int* input_seq_lens, const int* past_seq_lens,
+                           void* workspace, int src_blk, int tgt_blk,
+                           const float* mask, float scale, int num_thread);
+
 template <typename T>
 void GetBatchArrayLauncher(T* q, T* k, T* v, T* score, T* out, T** q_array,
                            T** k_array, T** v_array, T** score_array,
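For reference, the quantity SelfScaledDpAttention computes per head is softmax(scale * Q K^T) V, with scale typically set to 1/sqrt(size_per_head). The scalar sketch below spells out that math for a single head, without the masking, blocking (src_blk/tgt_blk), strides, or workspace of the real signature; it is a plain reference, not the DashInfer implementation.

// Hedged reference: plain scalar scaled dot-product attention,
// softmax(scale * Q K^T) V, for one head and one sequence. Illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void ScaledDpAttentionRef(float* output, const float* query, const float* key,
                          const float* value, int q_len, int kv_len,
                          int size_per_head, float scale) {
  std::vector<float> scores(kv_len);
  for (int i = 0; i < q_len; ++i) {
    // scores[j] = scale * dot(Q[i], K[j]); track the max for a stable softmax
    float max_score = -INFINITY;
    for (int j = 0; j < kv_len; ++j) {
      float s = 0.f;
      for (int d = 0; d < size_per_head; ++d)
        s += query[i * size_per_head + d] * key[j * size_per_head + d];
      scores[j] = s * scale;
      max_score = std::max(max_score, scores[j]);
    }
    // softmax over the key dimension
    float sum = 0.f;
    for (int j = 0; j < kv_len; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      sum += scores[j];
    }
    // output[i] = sum_j softmax(scores)[j] * V[j]
    for (int d = 0; d < size_per_head; ++d) {
      float o = 0.f;
      for (int j = 0; j < kv_len; ++j)
        o += scores[j] * value[j * size_per_head + d];
      output[i * size_per_head + d] = o / sum;
    }
  }
}

int main() {
  const int d = 4;
  std::vector<float> q(d, 1.f), k(2 * d, 1.f), v(2 * d, 0.5f), out(d);
  ScaledDpAttentionRef(out.data(), q.data(), k.data(), v.data(),
                       /*q_len=*/1, /*kv_len=*/2, d, 1.0f / std::sqrt(float(d)));
  std::printf("%f\n", out[0]);  // prints 0.500000
  return 0;
}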
