
Commit 1b9b010

support Qwen2, change dashinfer model extensions

- support Qwen2, add model_type Qwen_v20
- change dashinfer model extensions (asgraph, asparam -> dimodel, ditensors)
- remove xxx_quantize.json config file, use command line arg instead

1 parent: add989c

File tree

67 files changed: +867 −714 lines

README.md

Lines changed: 2 additions & 2 deletions
@@ -94,12 +94,12 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi
 
 ![Workflow and Dependency](documents/resources/image/workflow-deps.jpg?row=true)
 
-1. **Model Loading and Serialization**: This procedure involves loading model weights, setting up transformation parameters, and quantization settings. Based on this information, the model is serialized and converted into the DashInfer format (.asparam, .asgraph). This functionality is accessible exclusively through a Python interface and relies on the PyTorch and transformers libraries to access the weights. The version requirements for PyTorch and transformers may vary from model to model. DashInfer itself does not impose any specific version constraints.
+1. **Model Loading and Serialization**: This procedure involves loading model weights, setting up transformation parameters, and quantization settings. Based on this information, the model is serialized and converted into the DashInfer format (.dimodel, .ditensors). This functionality is accessible exclusively through a Python interface and relies on the PyTorch and transformers libraries to access the weights. The version requirements for PyTorch and transformers may vary from model to model. DashInfer itself does not impose any specific version constraints.
 
 2. **Model Inference**: This step is responsible for executing the model inference using the serialized model with DashInfer, without depending on components like PyTorch. DashInfer employs [DLPack](https://github.com/dmlc/dlpack) format tensors to facilitate interaction with external frameworks, such as PyTorch. Tensors in DLPack format can be manually created or generated through tensor conversion functions provided by deep learning frameworks. Regarding the C++ interface, since most dependencies have been statically linked, it primarily relies on the OpenMP runtime library and C++ system libraries. We applied [control over symbol exports](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) to ensure that only DashInfer's API interface symbols are visible, thereby preventing version conflicts with existing libraries in the user's system, such as protobuf.
 
 > Note:
-> - .asparam, .asgraph is a special model format defined by DashInfer kernel (allspark).
+> - .dimodel, .ditensors is a special model format defined by DashInfer kernel.
 > - When utilizing the Python interface, you can combine the code from steps 1 and 2. However, due to the lack of functionality for loading Huggingface models at the C++ level, the C++ interface is limited to conducting inferences with models in the DashInfer format. Therefore, it's essential to serialize the model first using the Python interface before proceeding with the C++ interface.
 
 ## Single-NUMA Architecture
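Since the inference step above interacts with external frameworks through DLPack, here is a minimal sketch of manually creating a DLPack tensor descriptor over a CPU buffer. This is not DashInfer's API: only the standard structs from dlpack.h are used, and the buffer and shape are illustrative.

#include <dlpack/dlpack.h>

#include <cstdint>
#include <vector>

int main() {
  std::vector<float> buffer(2 * 8, 0.0f);  // illustrative CPU data
  std::vector<int64_t> shape = {2, 8};

  DLTensor tensor{};
  tensor.data = buffer.data();
  tensor.device = {kDLCPU, 0};       // CPU, device id 0
  tensor.ndim = static_cast<int32_t>(shape.size());
  tensor.dtype = {kDLFloat, 32, 1};  // float32, 1 lane
  tensor.shape = shape.data();
  tensor.strides = nullptr;          // nullptr means compact row-major
  tensor.byte_offset = 0;
  // `tensor` can now be handed to any DLPack-aware consumer.
  return 0;
}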

README_CN.md

Lines changed: 17 additions & 4 deletions
@@ -1,6 +1,19 @@
-# Introduction
+<div align="center">
+
+[![PyPI](https://img.shields.io/pypi/v/dashinfer)](https://pypi.org/project/dashinfer/)
+<!-- [![Documentation Status](https://readthedocs.org/projects/easy-cv/badge/?version=latest)](https://easy-cv.readthedocs.io/en/latest/) -->
+
+<h4 align="center">
+    <p>
+        <a href="https://github.com/modelscope/dash-infer/blob/main/README.md">English</a> |
+        <b>中文</b>
+    </p>
+</h4>
 
-DashInfer is an inference engine for pre-trained large language models (LLMs).
+
+</div>
+
+# Introduction
 
 DashInfer is written in C++ Runtime and provides C++ and Python interfaces. It delivers production-grade performance on multiple CPU architectures, including x86 and ARMv9, and supports Continuous Batching and NUMA-Aware multi-NUMA inference, making full use of server-grade CPU compute and offering more hardware choices for serving LLMs of 14B parameters and below.
 
@@ -82,12 +95,12 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$
 
 ![Workflow and Dependency](documents/resources/image/workflow-deps.jpg?row=true)
 
-1. **Model loading and serialization**: This step reads the model weights, configures the model conversion and quantization parameters, serializes the model accordingly, and produces a model in DashInfer format (.asparam, .asgraph). Only a Python interface is provided for this functionality, and it relies on PyTorch and the transformers library to access the weights. The PyTorch and transformers version requirements may differ from model to model; DashInfer itself has no special version requirements.
+1. **Model loading and serialization**: This step reads the model weights, configures the model conversion and quantization parameters, serializes the model accordingly, and produces a model in DashInfer format (.dimodel, .ditensors). Only a Python interface is provided for this functionality, and it relies on PyTorch and the transformers library to access the weights. The PyTorch and transformers version requirements may differ from model to model; DashInfer itself has no special version requirements.
 
 2. **Model inference**: This step executes inference with DashInfer on the serialized model, without depending on components such as PyTorch. DashInfer uses tensors in [DLPack](https://github.com/dmlc/dlpack) format to interact with external frameworks (e.g., PyTorch). DLPack tensors can be created manually or produced by the tensor conversion functions of deep learning frameworks. For the C++ interface, since almost all dependencies are statically linked, it only depends on the OpenMP runtime library and the C++ system libraries. We applied [control over symbol exports](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) so that only DashInfer's API symbols are visible, avoiding version conflicts with common libraries already present on the user's system (such as protobuf).
 
 > Note:
-> - .asparam and .asgraph are a special model format defined by the DashInfer kernel (allspark).
+> - .dimodel and .ditensors are a special model format defined by the DashInfer kernel.
 > - When using the Python interface, the code of steps 1 and 2 can be combined. Since there is no C++-level functionality for loading Huggingface models, the C++ interface can only run inference on models in DashInfer format; the model must therefore be serialized with the Python interface first, before the C++ interface is used.
 
 ## Single-NUMA Architecture

csrc/common/as_engine.cpp

Lines changed: 25 additions & 40 deletions
@@ -245,10 +245,6 @@ AsStatus AsEngineImpl::SetNumThreads(int num_threads) {
   DLOG(INFO) << "AsEngineImpl::SetNumThreads()" << std::endl;
   AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
   device_ctx_->SetNumThreads(num_threads);
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < workers_.size(); ++i) {
     result[i] = threadpool_->enqueue([this, i, &num_threads]() {
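This hunk and the ones below all touch the same fan-out pattern: enqueue one lambda per rank on the shared thread pool, then join the futures. A stripped-down sketch of that pattern, with std::async standing in for the engine's ThreadPool::enqueue (all names illustrative):

#include <future>
#include <vector>

int main() {
  const int nranks = 4;
  std::vector<std::future<bool>> results;
  results.reserve(nranks);
  for (int i = 0; i < nranks; ++i) {
    // per-rank work, e.g. workers_[i]->StopRequest(request_id)
    results.emplace_back(std::async(std::launch::async, [i] { return i >= 0; }));
  }
  bool ok = true;
  for (auto& f : results) ok = f.get() && ok;  // join every rank before reporting
  return ok ? 0 : 1;
}

Removing the `nranks_ > threadpool_size_` resize at each of these call sites suggests the pool is now sized once up front rather than grown lazily per call.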
@@ -561,10 +557,6 @@ AsStatus AsEngineImpl::UnloadModelFromDeviceMemory(const char* model_name) {
   DLOG(INFO) << "[" << model_name << "] "
              << "AsEngineImpl::UnloadModelFromDeviceMemory()" << std::endl;
   AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue(
@@ -635,10 +627,6 @@ AsStatus AsEngineImpl::StartModel(const char* model_name, bool do_warmup) {
   int64_t min_bytes_available = std::numeric_limits<int64_t>::max();
   int64_t rank_0_bytes_available{0};
   if (use_adaptive_cache_) {
-    if (nranks_ > threadpool_size_) {
-      threadpool_size_ = nranks_ * 2;
-      threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-    }
     std::future<int64_t> result[nranks_];
     for (int i = 0; i < nranks_; ++i) {
       result[i] = threadpool_->enqueue([this, i]() -> int64_t {
@@ -850,6 +838,7 @@ AsStatus AsEngineImpl::StopModel(const char* model_name) {
   model_state->cond_var->notify_all();
 
   auto ret = reply_promise->get_future().get();
+  model_state->model_stopping = true;
 
   if (ret != AsStatus::ALLSPARK_SUCCESS) {
     LOG(ERROR) << "[" << model_name << "] "
@@ -885,10 +874,6 @@ AsStatus AsEngineImpl::ReloadModelFromDeviceMemory(const char* model_name) {
   }
 
   AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue([this, i, &model_ir]() {
@@ -1159,10 +1144,6 @@ AsStatus AsEngineImpl::RunTextGenerationContinue(const char* model_name) {
     return AsStatus::ALLSPARK_INVALID_CALL_ERROR;
   }
   AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue([this, i]() {
@@ -1236,10 +1217,6 @@ AsStatus AsEngineImpl::RunTextGenerationContext(const char* model_name) {
     return AsStatus::ALLSPARK_INVALID_CALL_ERROR;
   }
   AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue([this, i]() {
@@ -1293,18 +1270,15 @@ AsStatus AsEngineImpl::StopRequestByRequestID(const char* model_name,
     LOG(ERROR) << "Invalid model name : " << model_name << std::endl;
     return AsStatus::ALLSPARK_PARAM_ERROR;
   }
-  AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
+
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue([this, i, request_id]() {
       return workers_[i]->StopRequest(request_id);
     });
   }
 
+  AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
   // Even on failure or exception, let all worker threads run to completion to
   // preserve atomicity; in recoverable cases, this ensures a clean environment
   // for the next request.
   AsStatus failed_ret = AsStatus::ALLSPARK_SUCCESS;
   for (int i = 0; i < nranks_; ++i) {
@@ -1333,18 +1307,15 @@ AsStatus AsEngineImpl::ReleaseRequestByRequestID(const char* model_name,
     LOG(ERROR) << "Invalid model name : " << model_name << std::endl;
     return AsStatus::ALLSPARK_PARAM_ERROR;
   }
-  AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
+
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue([this, i, request_id]() {
       return workers_[i]->ReleaseRequest(request_id);
     });
   }
 
+  AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
   // Even on failure or exception, let all worker threads run to completion to
   // preserve atomicity; in recoverable cases, this ensures a clean environment
   // for the next request.
   AsStatus failed_ret = AsStatus::ALLSPARK_SUCCESS;
   for (int i = 0; i < nranks_; ++i) {
@@ -1461,10 +1432,6 @@ AsStatus AsEngineImpl::StartRequestImpl(
       {out_name, std::make_shared<AsTensor>(out_name, DeviceType::CPU,
                                             DataType::INT64, DataMode::DENSE,
                                             Shape{1, engine_max_length_})});
-  if (nranks_ > threadpool_size_) {
-    threadpool_size_ = nranks_ * 2;
-    threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
-  }
   std::future<AsStatus> result[nranks_];
   for (int i = 0; i < nranks_; ++i) {
     result[i] = threadpool_->enqueue(
@@ -1668,7 +1635,6 @@ void AsEngineImpl::ModelRunningThread(
                  s.c_str());  // set the name (pthread_self() returns the
                               // pthread_t of the current thread)
   bool looping = true;
-  long loop_cnt = 0;
   bool graceful_stop_phase = false;
   bool graceful_final_released = false;
   std::unique_ptr<EngineControlMessage> graceful_stop_msg = nullptr;
@@ -1682,7 +1648,6 @@ void AsEngineImpl::ModelRunningThread(
 
   while (looping) {
     util::Timer time_outer;
-    loop_cnt++;
     UpdateAsEngineStat();
     // print the engine state for easier service trace.
     // for multiple numa, only print this info on node 0.
@@ -1957,6 +1922,7 @@ void AsEngineImpl::ModelRunningThread(
     if (graceful_final_released) {
       assert(graceful_stop_msg != nullptr);
       graceful_stop_msg->promise->set_value(AsStatus::ALLSPARK_SUCCESS);
+      model_state->model_stopped = true;
      DLOG(INFO) << "All done, gracefully stopped!";
       break;
     }
@@ -2144,4 +2110,23 @@ std::string AsEngineStat::ToString() const {
   return result;
 }
 
+std::map<std::string, std::string> AsEngineStat::ToMap() const {
+  std::map<std::string, std::string> engine_stat_map;
+  engine_stat_map["free_token"] = std::to_string(free_token);
+  engine_stat_map["total_token"] = std::to_string(total_token);
+  engine_stat_map["pendding_request"] = std::to_string(pendding_request);
+  engine_stat_map["running_request"] = std::to_string(running_request);
+  engine_stat_map["total_device_memory_pool_size"] =
+      std::to_string(total_device_memory_pool_size);
+  engine_stat_map["used_device_memory_pool_size"] =
+      std::to_string(used_device_memory_pool_size);
+  engine_stat_map["total_generated_token"] =
+      std::to_string(total_generated_token);
+  engine_stat_map["total_prefill_token"] = std::to_string(total_prefill_token);
+  engine_stat_map["generate_token_persec"] =
+      std::to_string(generate_token_persec);
+  engine_stat_map["process_token_persec"] =
+      std::to_string(process_token_persec);
+  return engine_stat_map;
+}
 }  // namespace allspark
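The new AsEngineStat::ToMap() exposes the engine counters as string key/value pairs, which is convenient for logging or a metrics endpoint. A hedged usage sketch: how the AsEngineStat instance is obtained is hypothetical here; only ToMap() and its keys come from this commit.

#include <iostream>
#include <map>
#include <string>

// Print every engine counter, e.g. free_token, running_request, ...
void DumpEngineStat(const std::map<std::string, std::string>& stat_map) {
  for (const auto& kv : stat_map) {
    std::cout << kv.first << " = " << kv.second << '\n';
  }
}
// usage, given an AsEngineStat `stat` (accessor hypothetical):
// DumpEngineStat(stat.ToMap());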

csrc/common/engine_runtime.h

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class ModelControlState final {
       result_queue_map;
 
   bool model_stopping = false;  // after GracefulStopModel called...
+  bool model_stopped = false;   // after GracefulStopModel is done.
 
   ModelControlState(const std::string& name) : model_name(name) {
     lock = std::make_unique<std::mutex>();
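Together with model_stopping set in StopModel and model_stopped set at the end of ModelRunningThread, these flags form a two-phase shutdown handshake. A generic sketch of that shape — not the engine's actual loop, which also drains pending requests and signals through a promise and condition variable:

#include <atomic>

struct StopFlags {
  std::atomic<bool> stopping{false};  // set by the caller (StopModel)
  std::atomic<bool> stopped{false};   // set by the loop thread once drained
};

void RunLoop(StopFlags& flags) {
  while (!flags.stopping.load(std::memory_order_acquire)) {
    // ... process queued requests ...
  }
  // ... release remaining requests (the graceful-stop phase) ...
  flags.stopped.store(true, std::memory_order_release);  // acknowledge stop
}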

csrc/core/model/qwen/qwen.cpp

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@ AsStatus QwenModel::Init(const TransformerProto& model_proto,
 REGISTER_MODEL("Qwen", QwenModel)
 REGISTER_MODEL("Qwen_v10", QwenModel_v10)
 REGISTER_MODEL("Qwen_v15", QwenModel_v15)
+REGISTER_MODEL("Qwen_v20", QwenModel_v20)
 }  // namespace allspark

csrc/core/model/qwen/qwen.h

Lines changed: 6 additions & 0 deletions
@@ -31,4 +31,10 @@ class QwenModel_v15 : public QwenModel {
       : QwenModel(model_type){};
 };
 
+class QwenModel_v20 : public QwenModel {
+ public:
+  explicit QwenModel_v20(const std::string& model_type = "")
+      : QwenModel(model_type){};
+};
+
 }  // namespace allspark
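Adding a new Qwen revision is mechanical: subclass QwenModel in qwen.h and register its model_type string in qwen.cpp, exactly as this commit does for Qwen_v20. A self-contained sketch of that pattern; QwenModelBase stands in for the real QwenModel, and "Qwen_v99" is illustrative only:

#include <string>

// Stand-in for the real QwenModel base class.
class QwenModelBase {
 public:
  explicit QwenModelBase(const std::string& model_type = "")
      : model_type_(model_type) {}
  virtual ~QwenModelBase() = default;

 private:
  std::string model_type_;
};

// A hypothetical future revision mirrors QwenModel_v20 exactly:
class QwenModel_v99 : public QwenModelBase {
 public:
  explicit QwenModel_v99(const std::string& model_type = "")
      : QwenModelBase(model_type) {}
};

// and would be registered in qwen.cpp under its model_type string:
// REGISTER_MODEL("Qwen_v99", QwenModel_v99)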

csrc/core/operator/general/get_last_line/get_last_line.h

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ namespace allspark {
 class GetLastLineOp : public AsOperator {
  public:
   explicit GetLastLineOp(const std::string& op_type = "")
-      : AsOperator(op_type), batch_(0), seq_(0), hid_(0) {}
+      : AsOperator(op_type) {}
   AsStatus Init(const OperatorProto& op_proto, const DeviceContext& ctx,
                 const TensorMap& weights_map, TensorMap* tensor_map);
   AsStatus Reshape() override;

csrc/core/operator/general/rotary/rotary_op.cpp

Lines changed: 26 additions & 19 deletions
@@ -98,24 +98,33 @@ AsStatus RotaryOp::Init(const OperatorProto& op_proto, const DeviceContext& ctx,
   if (attr_map.find("rotary_base") != attr_map.end()) {
     base_ = *(float*)(attr_map.at("rotary_base").c_str());
   }
-  xlogn_ = -1;
-  if (attr_map.find("logn_model_embedding") != attr_map.end()) {
-    xlogn_ = *(int*)(attr_map.at("logn_model_embedding").c_str());
+
+  // attr.8 multi_query_group_num
+  if (attr_map.find("multi_query_group_num") != attr_map.end()) {
+    group_num_ = *(int*)(attr_map.at("multi_query_group_num").c_str());
+  } else {
+    group_num_ = num_heads_;
   }
+  size_per_head_ = ctx.GetSizePerHead();
 
   // backend switch
   DeviceType backend = ctx.GetDeviceType();
   switch (backend) {
     case DeviceType::CPU: {
       const CPUContext* cpu_ctx = static_cast<const CPUContext*>(ctx_);
       num_heads_ /= cpu_ctx->GetNranks();
+      if (group_num_ != 1) {
+        group_num_ /= cpu_ctx->GetNranks();
+      }
       break;
     }
     default:
       LOG(ERROR) << op_type_ << " Operator does not support "
                  << DeviceType_Name(backend) << " device type" << std::endl;
       return AsStatus::ALLSPARK_RUNTIME_ERROR;
   }
+  kv_stride_ = size_per_head_ * group_num_;
+  hidden_size_ = size_per_head_ * num_heads_;
   return AsStatus::ALLSPARK_SUCCESS;
 }
 AsStatus RotaryOp::Reshape(RuntimeContext* runtime_ctx) {
@@ -124,42 +133,40 @@ AsStatus RotaryOp::Reshape(RuntimeContext* runtime_ctx) {
   Shape y_shape(x_shape);
   batch_size_ = y_shape[0];
   seq_len_ = y_shape[1];
-  hidden_size_ = y_shape[2] / 3;
-  tensor_map_->at(out_names_[0])->SetShape(std::move(y_shape));
-  // set variable
-  if (hidden_size_ % num_heads_) {
-    LOG(ERROR) << "Invalid attribute in RotaryOp. hidden_size : "
-               << hidden_size_ << ", num_heads : " << num_heads_ << std::endl;
+  qkv_stride_ = y_shape[2];
+  if (qkv_stride_ != hidden_size_ + 2 * kv_stride_) {
+    LOG(ERROR) << "Invalid qkv_stride_ in RotaryOp"
+               << ", qkv_strde = " << qkv_stride_
+               << ", hidden_size = " << hidden_size_
+               << ", kv_stride = " << kv_stride_ << std::endl;
     return AsStatus::ALLSPARK_RUNTIME_ERROR;
   }
-  size_per_head_ = hidden_size_ / num_heads_;
-  gemm_batch_ = batch_size_ * num_heads_;
+  tensor_map_->at(out_names_[0])->SetShape(std::move(y_shape));
   return AsStatus::ALLSPARK_SUCCESS;
 }
 AsStatus RotaryOp::RunRotary(int run_batch_size, AsTensor* rotary_step,
                              AsTensor* rotary_inv_freq) {
   int* run_step = (int*)rotary_step->GetDataPtr();
   float* inv_freq = (float*)rotary_inv_freq->GetDataPtr();
-  int qkv_stride = 3 * hidden_size_;
+  int qkv_stride = qkv_stride_;
   int* batch_offset = nullptr;
-  int offset = hidden_size_ * SizeofType(dtype_);
   void* q_buf = (char*)tensor_map_->at(in_names_[0])->GetDataPtr();
-  void* k_buf = (char*)q_buf + offset;
-  void* v_buf = (char*)k_buf + offset;
+  void* k_buf = (char*)q_buf + hidden_size_ * SizeofType(dtype_);
+  void* v_buf = (char*)k_buf + kv_stride_ * SizeofType(dtype_);
   void* outq_buf = (char*)tensor_map_->at(out_names_[0])->GetDataPtr();
-  void* outk_buf = (char*)outq_buf + offset;
-  void* outv_buf = (char*)outk_buf + offset;
+  void* outk_buf = (char*)outq_buf + hidden_size_ * SizeofType(dtype_);
+  void* outv_buf = (char*)outk_buf + kv_stride_ * SizeofType(dtype_);
 
   rotary_launcher(dtype_, outq_buf, q_buf, inv_freq, batch_offset,
                   run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
                   size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_,
                   xlogn_, ctx_);
   rotary_launcher(dtype_, outk_buf, k_buf, inv_freq, batch_offset,
-                  run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
+                  run_batch_size, seq_len_, run_step, hidden_size_, group_num_,
                   size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_, -1,
                   ctx_);
   rotary_launcher(dtype_, outv_buf, v_buf, nullptr, batch_offset,
-                  run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
+                  run_batch_size, seq_len_, run_step, hidden_size_, group_num_,
                   size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_, -1,
                   ctx_);
   return AsStatus::ALLSPARK_SUCCESS;
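The reworked strides generalize the fused QKV layout from plain multi-head attention (where y_shape[2] was 3 * hidden_size) to grouped-query attention, where K and V carry only group_num heads. A small arithmetic sketch of the [Q | K | V] row layout the new checks assume; the head counts are made-up, GQA-style numbers:

#include <cassert>

int main() {
  const int num_heads = 32;      // query heads (per rank)
  const int group_num = 4;       // KV heads (multi_query_group_num)
  const int size_per_head = 128;

  const int hidden_size = size_per_head * num_heads;   // Q width: 4096
  const int kv_stride = size_per_head * group_num;     // K (or V) width: 512
  const int qkv_stride = hidden_size + 2 * kv_stride;  // fused row: 5120

  // Offsets inside one fused row, matching the pointer math in RunRotary:
  const int k_offset = hidden_size;              // K starts after Q
  const int v_offset = hidden_size + kv_stride;  // V starts after K
  assert(k_offset == hidden_size);
  assert(v_offset + kv_stride == qkv_stride);    // V ends exactly at row end
  return 0;
}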

csrc/core/operator/general/rotary/rotary_op.h

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ class RotaryOp : public AsOperator {
   float rotary_pct_;
   float seqlen_extrapolation_;
   int ntk_model_embed_;
+  int group_num_ = 0;
+  int qkv_stride_ = 0;
+  int kv_stride_ = 0;
 };
 
 }  // namespace allspark
