
Commit afa12c3

Support Baichuan-7B and Baichuan2-7B & 13B (#38)
1 parent 2be2060 commit afa12c3

21 files changed: +1090 additions, -54 deletions

README.md

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi
 | ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
 | LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
 | LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |
+| BaiChuanForCausalLM | Baichuan | Baichuan | [baichuan-inc/Baichuan-7B](https://huggingface.co/baichuan-inc/Baichuan-7B) | [baichuan-inc/baichuan-7B](https://modelscope.cn/models/baichuan-inc/baichuan-7B) | / |
+| BaichuanForCausalLM | Baichuan2 | Baichuan_v2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | [baichuan-inc/Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat) | / |
 
 # Software Architecture
 

README_CN.md

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$
 | ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
 | LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
 | LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |
+| BaiChuanForCausalLM | Baichuan | Baichuan | [baichuan-inc/Baichuan-7B](https://huggingface.co/baichuan-inc/Baichuan-7B) | [baichuan-inc/baichuan-7B](https://modelscope.cn/models/baichuan-inc/baichuan-7B) | / |
+| BaichuanForCausalLM | Baichuan2 | Baichuan_v2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | [baichuan-inc/Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat) | / |
 
 # 软件框架
 
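Note: the hunk context above quotes the README's quantization formula, x_u8 = x_fp32 / scale + zeropoint. As a minimal standalone illustration of that round trip (our own sketch with invented names, not DashInfer's actual kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize: x_u8 = round(x_fp32 / scale + zeropoint), clamped to [0, 255].
std::vector<std::uint8_t> quantize_u8(const std::vector<float>& x, float scale,
                                      float zeropoint) {
  std::vector<std::uint8_t> q(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    float v = std::round(x[i] / scale + zeropoint);
    q[i] = static_cast<std::uint8_t>(std::clamp(v, 0.0f, 255.0f));
  }
  return q;
}

// Dequantize: recover an approximate float for the matmul path.
float dequantize_u8(std::uint8_t q, float scale, float zeropoint) {
  return (static_cast<float>(q) - zeropoint) * scale;
}
```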

build.sh

Lines changed: 1 addition & 18 deletions
@@ -2,29 +2,12 @@ set -x
 
 clean="OFF"
 
-# capture the output of the arch command
-architecture=$(arch)
-
-# use an if-else to choose the platform
-if [ "${architecture}" == "aarch64" ]; then
-  export AS_PLATFORM=armclang
-else
-  export AS_PLATFORM=x86
-fi
-
-if [ -z "$AS_PLATFORM" ];
-then
-  echo " please set AS_PLATFORM env, AS_PLATFORM can be x86 or armclang"
-  exit 1
-fi
-
 # with_platform, to support x86/arm build
-with_platform="${AS_PLATFORM}"
+with_platform="${AS_PLATFORM:-x86}"
 build_type="${AS_BUILD_TYPE:-Release}"
 build_package="${AS_BUILD_PACKAGE:-OFF}"
 enable_glibcxx11_abi="${AS_CXX11_ABI:-ON}" # default enable cxx11 ABI
 
-
 function clone_pull {
   GIT_URL=$1
   DIRECTORY=$2

csrc/core/kernel/cpu/ALiBiPE.cpp

Lines changed: 18 additions & 11 deletions
@@ -39,7 +39,7 @@ void ALiBiPE_kernel(T* out, int* batch_offset, int batch_size, int seq_length,
   parallel_for(N, [&](int idx) {
     int batch = idx / num_heads;
     int head = idx % num_heads;
-    int offset = batch_offset[batch];
+    int offset = batch_offset ? batch_offset[batch] : 0;
     float slope = get_ALiBiPE_slope(head, num_heads, ori_num_heads, rank);
     for (int i = 0; i < seq_length; i++) {
       for (int j = 0; j < seq_length; j++) {
@@ -53,42 +53,49 @@
 template <typename T>
 void ALiBiPE_decoder_kernel(T* out, int* batch_offset, int batch_size,
                             int seq_length, int num_heads, int ori_num_heads,
-                            int rank, int N) {
+                            int rank, int N, std::vector<int>& step_list) {
   // return [batch,1,num_heads,seq_length],i=seq_length-1
   parallel_for(N, [&](int idx) {
     int batch = idx / num_heads;
     int head = idx % num_heads;
-    int offset = batch_offset[batch];
+    int step = step_list[batch];
+    int offset = batch_offset ? batch_offset[batch] : 0;
     float slope = get_ALiBiPE_slope(head, num_heads, ori_num_heads, rank);
-    for (int j = 0; j < seq_length; j++) {
-      out[batch * num_heads * 1 * seq_length + head * seq_length + j] =
+    for (int j = 0; j < step; j++) {
+      out[batch * num_heads * 1 * seq_length + head * step + j] =
          slope * (j - offset);
+      // we take this output tensor as a one-dimensional array in batch_MHA
+      // afterwards so it's 'head * step', not 'head * seq_length' otherwise the
+      // values updated will not be consecutively stored
     }
   });
 }
 template <typename T>
 void ALiBiPEKernelLauncher(T* out, int* batch_offset, int batch_size,
                            int seq_length, int num_heads, int ori_num_heads,
-                           int step, int rank) {
+                           int rank, bool is_context,
+                           std::vector<int>& step_list) {
   int N = batch_size * num_heads;
-  if (step - 1 == 0) {
+  if (is_context == true) {
     ALiBiPE_kernel(out, batch_offset, batch_size, seq_length, num_heads,
                    ori_num_heads, rank, N);
   } else {
-    ALiBiPE_decoder_kernel(out, batch_offset, batch_size, step, num_heads,
-                           ori_num_heads, rank, N);
+    ALiBiPE_decoder_kernel(out, batch_offset, batch_size, seq_length, num_heads,
+                           ori_num_heads, rank, N, step_list);
   }
 }
 
 template void ALiBiPEKernelLauncher<float>(float* out, int* batch_offset,
                                            int batch_size, int seq_length,
                                            int num_heads, int ori_num_heads,
-                                           int step, int rank);
+                                           int rank, bool is_context,
+                                           std::vector<int>& step_list);
 #ifdef ENABLE_FP16
 template void ALiBiPEKernelLauncher<half>(half* out, int* batch_offset,
                                           int batch_size, int seq_length,
                                           int num_heads, int ori_num_heads,
-                                          int step, int rank);
+                                          int rank, bool is_context,
+                                          std::vector<int>& step_list);
 #endif
 }  // namespace cpu
 }  // namespace allspark
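Note on the decoder-kernel change above: when requests in a batch sit at different generation steps, each (batch, head) lane now writes its own `step` bias values contiguously at `head * step`, so downstream batch_MHA can read the tensor back as a flat array. A standalone sketch of that layout follows; the slope formula here is the standard ALiBi 2^(-8(h+1)/n) for power-of-two head counts, whereas the repo's `get_ALiBiPE_slope` (not shown in this diff) also handles `ori_num_heads` and tensor-parallel `rank`, and `batch_offset` is taken as null, i.e. offset 0:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Standard ALiBi slope for head h of n heads (n a power of two).
float alibi_slope(int head, int num_heads) {
  return std::pow(2.0f, -8.0f * (head + 1) / num_heads);
}

int main() {
  int num_heads = 4, max_step = 8;      // tensor: [batch, 1, heads, max_step]
  std::vector<int> step_list = {8, 5};  // two requests at different steps
  int batch_size = static_cast<int>(step_list.size());
  std::vector<float> out(batch_size * num_heads * max_step, 0.0f);
  for (int batch = 0; batch < batch_size; ++batch) {
    int step = step_list[batch];
    for (int head = 0; head < num_heads; ++head) {
      float slope = alibi_slope(head, num_heads);
      for (int j = 0; j < step; ++j)
        // 'head * step', not 'head * max_step': each lane's values stay
        // consecutive, which is how batch_MHA reads the tensor back.
        out[batch * num_heads * max_step + head * step + j] = slope * j;
    }
  }
  std::printf("request 1, head 0, pos 4 -> %f\n",
              out[1 * num_heads * max_step + 0 * 5 + 4]);
  return 0;
}
```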

csrc/core/kernel/cpu/cpu_kernel.h

Lines changed: 3 additions & 1 deletion
@@ -7,6 +7,7 @@
 #include <common.h>
 #include <stdint.h>
 
+#include <map>
 #include <vector>
 namespace allspark {
 namespace cpu {
@@ -90,7 +91,8 @@ void RelativePEKernel(T* out, const T* attention_bias, int batch_size,
 template <typename T>
 void ALiBiPEKernelLauncher(T* out, int* batch_offset, int batch_size,
                            int seq_length, int num_heads, int ori_num_heads,
-                           int step, int rank);
+                           int rank, bool is_context,
+                           std::vector<int>& step_list);
 template <typename T>
 void MHAKernel(T* out, const T* q, const T* k, const T* v, const float* mask,
                T* score, int beam_size, int batch_size, int num_heads,
baichuan.cpp (new file)

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*!
+ * Copyright (c) Alibaba, Inc. and its affiliates.
+ * @file baichuan.cpp
+ */
+
+#include "baichuan.h"  // NOLINT
+
+namespace allspark {
+AsStatus BaichuanModel::Init(const TransformerProto& model_proto,
+                             const DeviceContext& ctx) {
+  DLOG(INFO) << "BaichuanModel::Init()" << std::endl;
+  AS_CHECK_STATUS(AsModel::Init(model_proto, ctx));
+  topo_ops_.clear();
+  // parse graph
+  for (auto& op : graph_ops_["decoder"]) {
+    topo_ops_.emplace_back(op.get());
+  }
+  if (model_proto.model_conf().is_generate())
+    for (auto& op : graph_ops_["gen_graph"]) {
+      topo_ops_.emplace_back(op.get());
+    }
+  return AsStatus::ALLSPARK_SUCCESS;
+}
+REGISTER_MODEL("Baichuan_v2", BaichuanModel_v2)
+REGISTER_MODEL("Baichuan", BaichuanModel)
+}  // namespace allspark
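Note: `REGISTER_MODEL` presumably ties the model-type strings "Baichuan" / "Baichuan_v2" from the converted graph to factories for these classes; the macro itself is defined elsewhere in the repo. A minimal sketch of a self-registering factory of that shape (hypothetical, simplified types, not AllSpark's actual macro):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

struct Model {  // stand-in for AsModel
  virtual ~Model() = default;
};

using ModelFactory = std::function<std::unique_ptr<Model>(const std::string&)>;

// Meyers-singleton registry keyed by model-type string.
std::map<std::string, ModelFactory>& ModelRegistry() {
  static std::map<std::string, ModelFactory> r;
  return r;
}

// Self-registration at static-init time, roughly what a macro of this
// shape expands to.
#define REGISTER_MODEL(name, cls)                          \
  static const bool cls##_registered = [] {                \
    ModelRegistry()[name] = [](const std::string& type) {  \
      return std::unique_ptr<Model>(new cls(type));        \
    };                                                     \
    return true;                                           \
  }();

struct BaichuanModel : Model {
  explicit BaichuanModel(const std::string&) {}
};
REGISTER_MODEL("Baichuan", BaichuanModel)

// At load time the engine can then do:
//   auto model = ModelRegistry().at("Baichuan")("Baichuan");
```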
baichuan.h (new file)

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+/*!
+ * Copyright (c) Alibaba, Inc. and its affiliates.
+ * @file baichuan.h
+ */
+
+#pragma once
+
+#include <core/model/model.h>
+
+#include <string>
+
+namespace allspark {
+
+class BaichuanModel : public AsModel {
+ public:
+  explicit BaichuanModel(const std::string& model_type = "")
+      : AsModel(model_type) {}
+  AsStatus Init(const TransformerProto& model_proto,
+                const DeviceContext& ctx) override;
+};
+
+class BaichuanModel_v2 : public BaichuanModel {
+ public:
+  explicit BaichuanModel_v2(const std::string& model_type = "")
+      : BaichuanModel(model_type){};
+};
+
+
+
+}  // namespace allspark

csrc/core/operator/general/ALiBiPE/ALiBiPE_op.cpp

Lines changed: 41 additions & 16 deletions
@@ -11,15 +11,16 @@
 
 namespace allspark {
 AsStatus cpu_ALiBiPE(DataType dtype, void* out, int* batch_offset, int batch,
-                     int seq_len, int num_heads, int ori_num_heads, int step,
-                     const DeviceContext* ctx) {
+                     int seq_len, int num_heads, int ori_num_heads,
+                     const DeviceContext* ctx, bool is_context,
+                     std::vector<int>& step_list) {
   DLOG(INFO) << "cpu_ALiBiPE" << std::endl;
   const CPUContext* cpu_ctx = static_cast<const CPUContext*>(ctx);
   auto functor = [&]<typename T>() {
     T* typed_out = static_cast<T*>(out);
     cpu::ALiBiPEKernelLauncher(typed_out, batch_offset, batch, seq_len,
-                               num_heads, ori_num_heads, step,
-                               cpu_ctx->GetRank());
+                               num_heads, ori_num_heads, cpu_ctx->GetRank(),
+                               is_context, step_list);
   };
   DispatchCPU(dtype, functor);
   return AsStatus::ALLSPARK_SUCCESS;
@@ -54,32 +55,56 @@ AsStatus ALiBiPEOp::Init(const OperatorProto& op_proto,
   return AsStatus::ALLSPARK_SUCCESS;
 }
 
-AsStatus ALiBiPEOp::Reshape() {
+AsStatus ALiBiPEOp::Reshape(RuntimeContext* runtime_ctx) {
   Shape in_shape = tensor_map_->at(in_names_[0])->GetShape();
-  batch_size_ = in_shape[0];
-  if (gen_ctx_->step == 0) {
+  if (runtime_ctx->is_context == true) {
+    batch_size_ = in_shape[0];
     seq_length_ = in_shape[1];
     Shape out_shape = Shape{batch_size_, seq_length_, num_heads_, seq_length_};
     AS_CHECK_STATUS(
         tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape)));
-  } else {
-    seq_length_ = 1;
   }
+  return AsStatus::ALLSPARK_SUCCESS;
+}
 
+AsStatus ALiBiPEOp::runContext(RuntimeContext* runtime_ctx) {
+  int* batch_offset = nullptr;
+  AsTensor* out_tensor = tensor_map_->at(out_names_[0]).get();
+  std::vector<int> step_list;
+  kernel_launcher(out_tensor->GetDataType(), out_tensor->GetDataPtr(),
+                  batch_offset, batch_size_, seq_length_, num_heads_,
+                  ori_num_heads_, ctx_, runtime_ctx->is_context, step_list);
   return AsStatus::ALLSPARK_SUCCESS;
 }
 
-AsStatus ALiBiPEOp::Forward() {
+AsStatus ALiBiPEOp::runDecode(RuntimeContext* runtime_ctx) {
   int* batch_offset = nullptr;
   AsTensor* out_tensor = tensor_map_->at(out_names_[0]).get();
-  if (gen_ctx_->step != 0) {
-    Shape out_shape = Shape{batch_size_, 1, num_heads_, gen_ctx_->step + 1};
-    AS_CHECK_STATUS(
-        tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape)));
+  int batch_size = runtime_ctx->GetGenCtxListSize();
+  std::vector<int> step_list(batch_size);
+  int max_step = 1;
+  for (int i = 0; i < batch_size; i++) {
+    GenerateContext* gen_ctx = runtime_ctx->GetGenCtx(i);
+    if (gen_ctx->step + 1 > max_step) {
+      max_step = gen_ctx->step + 1;
+    }
+    step_list[i] = gen_ctx->step + 1;
   }
+  Shape out_shape = Shape{batch_size, 1, num_heads_, max_step};
+  AS_CHECK_STATUS(
+      tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape)));
   kernel_launcher(out_tensor->GetDataType(), out_tensor->GetDataPtr(),
-                  batch_offset, batch_size_, seq_length_, num_heads_,
-                  ori_num_heads_, (gen_ctx_->step + 1), ctx_);
+                  batch_offset, batch_size, max_step, num_heads_,
+                  ori_num_heads_, ctx_, runtime_ctx->is_context, step_list);
+  return AsStatus::ALLSPARK_SUCCESS;
+}
+
+AsStatus ALiBiPEOp::Forward(RuntimeContext* runtime_ctx) {
+  if (runtime_ctx->is_context == true) {
+    runContext(runtime_ctx);
+  } else {
+    runDecode(runtime_ctx);
+  }
   return AsStatus::ALLSPARK_SUCCESS;
 }
 
csrc/core/operator/general/ALiBiPE/ALiBiPE_op.h

Lines changed: 7 additions & 5 deletions
@@ -15,19 +15,21 @@ class ALiBiPEOp : public AsOperator {
       : AsOperator(op_type), batch_size_(1), seq_length_(1), num_heads_(1) {}
   AsStatus Init(const OperatorProto& op_proto, const DeviceContext& ctx,
                 const TensorMap& weights_map, TensorMap* tensor_map);
-  AsStatus Reshape() override;
-  AsStatus Forward() override;
+  AsStatus Reshape(RuntimeContext* runtime_ctx) override;
+  AsStatus Forward(RuntimeContext* runtime_ctx) override;
 
  private:
+  AsStatus runContext(RuntimeContext* runtime_ctx);
+  AsStatus runDecode(RuntimeContext* runtime_ctx);
   AsStatus (*kernel_launcher)(DataType dtype, void* out, int* batch_offset,
                               int batch, int seq_len, int num_heads,
-                              int ori_num_heads, int step,
-                              const DeviceContext* ctx) = nullptr;
+                              int ori_num_heads, const DeviceContext* ctx,
+                              bool is_context,
+                              std::vector<int>& step_list) = nullptr;
   int batch_size_;
   int seq_length_;
   int num_heads_;
   int ori_num_heads_;
-  int max_seq_;
 };
 
 }  // namespace allspark
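Note: `kernel_launcher` stays a plain function pointer so the op can bind a device-specific implementation once (the binding site, presumably in `Init`, is not part of this diff). A simplified sketch of the pattern with stand-in types:

```cpp
#include <cstdio>

// Simplified stand-in for the real launcher signature.
using KernelLauncher = void (*)(int batch, bool is_context);

void cpu_alibi_pe(int batch, bool is_context) {
  std::printf("cpu kernel: batch=%d, context=%d\n", batch, is_context);
}

struct AlibiOp {
  KernelLauncher kernel_launcher = nullptr;
  void Init() { kernel_launcher = cpu_alibi_pe; }  // bind once per device
  void Forward(bool is_context) { kernel_launcher(2, is_context); }
};

int main() {
  AlibiOp op;
  op.Init();
  op.Forward(true);   // context phase
  op.Forward(false);  // decode phase
  return 0;
}
```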

csrc/core/operator/generate_opt/batch_mha/batch_mha_op.cpp

Lines changed: 9 additions & 2 deletions
@@ -310,8 +310,15 @@ AsStatus BatchMHAOp::runOneBatch(GenerateContext* gen_ctx, int current_batch) {
   if (tensor_map_->at(in_names_[1])->GetShape().Count() == 0) {
     mask_buf = nullptr;
   }
-  void* position_embedding =
-      pos_embedding_ ? tensor_map_->at(in_names_[2])->GetDataPtr() : nullptr;
+  void* position_embedding = nullptr;
+  if (pos_embedding_ == true) {
+    const Shape& embedding_shape = tensor_map_->at(in_names_[2])->GetShape();
+    // shape: [batch_size, 1, num_heads, step + 1]
+    // in context phase, 'current_batch' passed by caller will always be 0
+    position_embedding = (char*)tensor_map_->at(in_names_[2])->GetDataPtr() +
+                         current_batch * embedding_shape[2] *
+                             embedding_shape[3] * SizeofType(dtype_);
+  }
   char* score_buf = (char*)(tensor_map_->at("workspace")->GetDataPtr());
   void** q_array = (void**)(score_buf + score_size_);
   void** k_array = q_array + round32(gemm_batch_);
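Note on the offset arithmetic above: with the decode-phase bias laid out as [batch_size, 1, num_heads, max_step] and `runOneBatch` handling one request at a time, the base pointer advances by `current_batch * num_heads * max_step` elements, converted to bytes. A standalone sketch with example sizes (fp32 assumed):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // bias tensor: [batch_size, 1, num_heads, max_step], fp32 elements
  int num_heads = 4, max_step = 8, elem_bytes = 4;
  for (int current_batch = 0; current_batch < 3; ++current_batch) {
    std::size_t byte_offset = static_cast<std::size_t>(current_batch) *
                              num_heads * max_step * elem_bytes;
    std::printf("request %d reads its bias at byte %zu\n", current_batch,
                byte_offset);
  }
  return 0;
}
```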
