
Commit 82c0373

Author: yejunjin

fix: change to size_t to avoid overflow when seq is long (#11)
1 parent: f940c2c
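The overflow this commit guards against is in the element offset into the attention score buffer: `k` is multiplied by `step` in expressions like `score + k * step`, and with `int k` that product is computed in 32-bit signed arithmetic, which wraps once a long sequence pushes it past `INT_MAX`. A minimal, self-contained sketch of the failure mode, with shapes chosen purely for illustration (a hypothetical 32-head model at the new 32k context length):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative shapes only: 32 heads at a 32k context length.
  const int num_heads = 32, seq_length = 32768, step = 32768;
  const int m = 0, j = 32767, n = 31;

  // The flattened row index as computed in MHAKernel; here it is
  // 1,048,575, which still fits comfortably in an int...
  const int k = m * seq_length * num_heads + j * num_heads + n;

  // ...but the element offset k * step (34,359,705,600) does not.
  // The cast below simulates the 32-bit wraparound; in the real kernel
  // the signed overflow inside `score + k * step` is undefined behavior.
  const int32_t wrapped = (int32_t)((int64_t)k * (int64_t)step);
  const size_t fixed = (size_t)k * (size_t)step;  // what `size_t k` yields

  printf("int offset:    %d (negative -> out-of-bounds access)\n", wrapped);
  printf("size_t offset: %zu\n", fixed);
  return 0;
}
```

On LP64 platforms `size_t` is 64 bits wide, so declaring `k` as `size_t` makes the whole offset computation 64-bit.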

File tree

3 files changed: +9 −9 lines

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ Written in C++ runtime, DashInfer aims to deliver production-level implementatio
 - **Post Training Quantization (PTQ)**: Using DashInfer's InstantQuant (IQ), weight-only quantization acceleration can be achieved without fine-tuning, improving deployment efficiency. Accuracy evaluation shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
 - **Optimized Computation Kernels**: With OneDNN and self-developed assembly kernels, DashInfer is able to maximize the performance of the hardware on both ARM and x86.
 - **NUMA-Aware Design**: DashInfer supports tensor parallel inference across multiple NUMA nodes, fully leveraging the computing power of server CPUs. With numactl and a multi-process architecture, the NUMA affinity of threads is accurately controlled to maximize the performance of multi-node CPUs and avoid the performance degradation caused by cross-NUMA access. For more information on NUMA, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html).
-- **Context Length**: The current version supports up to 11k context length, with plans to extend to longer context lengths in the future.
+- **Context Length**: The current version supports up to 32k context length, with plans to extend to longer context lengths in the future.
 - **Multi-Language API Interfaces**: Both C++ and Python interfaces are supported. It is possible to extend C++ interface to Java, Rust and other programming languages, via standard cross-language interfaces.
 - **Operating System Support**: DashInfer supports mainstream Linux server operating systems like Centos7 and Ubuntu22.04, and provides corresponding Docker images.
```
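For readers unfamiliar with the node binding described in the NUMA-Aware bullet above, here is a rough libnuma sketch of what pinning one worker process to a NUMA node looks like. This is an illustration only, not DashInfer code; `numactl --cpunodebind=0 --membind=0 <worker>` achieves a similar effect from the command line, and `numa_set_preferred` is a softer form of `--membind`:

```cpp
#include <numa.h>   // libnuma; link with -lnuma
#include <cstdio>

int main() {
  if (numa_available() < 0) {
    std::fprintf(stderr, "NUMA is not supported on this system\n");
    return 1;
  }
  const int node = 0;        // hypothetical: this worker is assigned node 0
  numa_run_on_node(node);    // restrict this process's threads to node 0's CPUs
  numa_set_preferred(node);  // prefer memory allocations from node 0
  std::printf("configured NUMA nodes: %d\n", numa_num_configured_nodes());
  // ... run the per-node inference worker here ...
  return 0;
}
```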

README_CN.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ DashInfer is written in a C++ runtime and provides C++ and Python language interfaces. DashInfer…
 - **PTQ Quantization**: With DashInfer's InstantQuant (IQ), weight-only quantization acceleration is achieved without training or fine-tuning, improving deployment efficiency. Accuracy evaluation shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
 - **Optimized Computation Kernels**: Combining OneDNN with self-developed assembly kernels, DashInfer extracts the maximum performance of the hardware on both ARM and x86.
 - **NUMA-Aware**: Supports tensor parallel inference across multiple NUMA nodes, fully exploiting the compute power of server-grade CPUs. Through numactl and a multi-process architecture, the NUMA affinity of compute threads is precisely controlled, making full use of multi-node CPU performance and avoiding the performance degradation caused by cross-NUMA memory access. For NUMA performance guidance, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html)
-- **Context Length**: The current version supports an 11k context length; even longer context lengths will be supported in the future.
+- **Context Length**: The current version supports a 32k context length; even longer context lengths will be supported in the future.
 - **Multi-Language API Interfaces**: C++ and Python interfaces are provided; the C++ interface can be bridged directly to Java, Rust, and other programming languages.
 - **Operating System Support**: Supports mainstream Linux server operating systems such as CentOS 7 and Ubuntu 22.04, and provides corresponding Docker images.
```

csrc/core/kernel/cpu/mha.cpp

Lines changed: 7 additions & 7 deletions

```diff
@@ -607,7 +607,7 @@ void MHAKernel(float* out, const float* q, const float* k, const float* v,
                   size_per_head, alpha, q + qkv_offset, qkv_stride,
                   k + qkv_offset, qkv_stride, beta, score_buf, step);
       for (int j = 0; j < seq_length; ++j) {
-        int k = i * seq_length + j;
+        size_t k = i * seq_length + j;
         vSoftmaxMask(step, score + k * step,
                      mask + (m * seq_length + j) * step);
       }
@@ -627,7 +627,7 @@ void MHAKernel(float* out, const float* q, const float* k, const float* v,
                   size_per_head, alpha, q + qkv_offset, qkv_stride,
                   k + qkv_offset, qkv_stride, beta, score_buf, step);
       for (int j = 0; j < seq_length; ++j) {
-        int k = i * seq_length + j;
+        size_t k = i * seq_length + j;
         vSoftmax(step, score + k * step);
       }
       cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_length,
@@ -656,7 +656,7 @@ void MHAKernel(float* out, const float* q, const float* k, const float* v,
                   size_per_head, alpha, q + q_offset, q_stride, k + kv_offset,
                   kv_stride, beta, score_buf, step * num_heads);
       for (int j = 0; j < seq_length; ++j) {
-        int k = m * seq_length * num_heads + j * num_heads + n;
+        size_t k = m * seq_length * num_heads + j * num_heads + n;
         vSoftmaxMask(step, score + k * step,
                      mask + (m / beam_size * step + j) * step);
       }
@@ -676,7 +676,7 @@ void MHAKernel(float* out, const float* q, const float* k, const float* v,
                   size_per_head, alpha, q + q_offset, q_stride, k + kv_offset,
                   kv_stride, beta, score_buf, step * num_heads);
       for (int j = 0; j < seq_length; ++j) {
-        int k = m * seq_length * num_heads + j * num_heads + n;
+        size_t k = m * seq_length * num_heads + j * num_heads + n;
         ;
         vSoftmax(step, score + k * step);
       }
@@ -782,7 +782,7 @@ void BatchSoftmax<float>(float* score, const float* mask, int batch_size,
     int m = i / num_heads;
     int n = i % num_heads;
     parallel_for(seq_len, [&](int j) {
-      int k = m * seq_len * num_heads + j * num_heads + n;
+      size_t k = m * seq_len * num_heads + j * num_heads + n;
       vSoftmaxMask(step, score + k * step,
                    mask + (m / beam_size * step + j) * step);
     });
@@ -792,7 +792,7 @@
     int m = i / num_heads;
     int n = i % num_heads;
     parallel_for(seq_len, [&](int j) {
-      int k = m * seq_len * num_heads + j * num_heads + n;
+      size_t k = m * seq_len * num_heads + j * num_heads + n;
       vSoftmax(step, score + k * step);
     });
   });
@@ -813,7 +813,7 @@ void BatchDecoderSoftmax<float>(float* score, const float* mask, int batch_size,
            mask + (m / beam_size * input_len + input_len - 1) * input_len,
            input_len * sizeof(float));
     parallel_for(seq_len, [&](int j) {
-      int k = m * seq_len * num_heads + j * num_heads + n;
+      size_t k = m * seq_len * num_heads + j * num_heads + n;
       vSoftmaxMask(step, score + k * step, mask_in.data());
    });
  });
```
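Declaring `k` as `size_t` is sufficient here because the subsequent pointer arithmetic `score + k * step` is then carried out in `size_t`. One caveat, an observation about the diff rather than part of the commit: the right-hand side `m * seq_len * num_heads + j * num_heads + n` is still evaluated in `int`, so the fix assumes the flattened row index itself stays below `INT_MAX`. A hypothetical hardened variant would promote the index computation as well:

```cpp
#include <cstddef>

// Hypothetical hardened variant (not part of this commit): cast the first
// operand of each product so every intermediate result is a size_t and no
// 32-bit wraparound is possible (indices assumed non-negative).
size_t flat_score_offset(int m, int j, int n, int num_heads, int seq_len,
                         int step) {
  size_t k = (size_t)m * seq_len * num_heads + (size_t)j * num_heads + n;
  return k * (size_t)step;  // element offset into the score buffer
}
```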
