11 | 11 |
12 | 12 | namespace allspark { |
13 | 13 | AsStatus cpu_ALiBiPE(DataType dtype, void* out, int* batch_offset, int batch, |
14 | | - int seq_len, int num_heads, int ori_num_heads, int step, |
15 | | - const DeviceContext* ctx) { |
| 14 | + int seq_len, int num_heads, int ori_num_heads, |
| 15 | + const DeviceContext* ctx, bool is_context, |
| 16 | + std::vector<int>& step_list) { |
16 | 17 | DLOG(INFO) << "cpu_ALiBiPE" << std::endl; |
17 | 18 | const CPUContext* cpu_ctx = static_cast<const CPUContext*>(ctx); |
18 | 19 | auto functor = [&]<typename T>() { |
19 | 20 | T* typed_out = static_cast<T*>(out); |
20 | 21 | cpu::ALiBiPEKernelLauncher(typed_out, batch_offset, batch, seq_len, |
21 | | - num_heads, ori_num_heads, step, |
22 | | - cpu_ctx->GetRank()); |
| 22 | + num_heads, ori_num_heads, cpu_ctx->GetRank(), |
| 23 | + is_context, step_list); |
23 | 24 | }; |
24 | 25 | DispatchCPU(dtype, functor); |
25 | 26 | return AsStatus::ALLSPARK_SUCCESS; |
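Note that this first hunk only changes how the CPU wrapper forwards its arguments: the per-request `step_list` and the `is_context` flag now replace the single `step` value. The body of `cpu::ALiBiPEKernelLauncher` is not part of this diff, so the following is only a minimal standalone sketch of the kind of bias fill the new interface implies. The function name, the index layout, the rank-based head offset, and the slope recipe (the standard ALiBi slope of Press et al., `2^(-8*(h+1)/H)`) are assumptions, not the actual kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration only (not the real cpu::ALiBiPEKernelLauncher).
// Fills an ALiBi bias tensor laid out to match the shapes set below:
//   context phase: [batch, seq_len, num_heads, seq_len]
//   decode phase:  [batch, 1,       num_heads, max_step]
template <typename T>
void ALiBiBiasSketch(T* out, int batch, int seq_len, int num_heads,
                     int ori_num_heads, int rank, bool is_context,
                     const std::vector<int>& step_list) {
  const int q_len = is_context ? seq_len : 1;
  const int k_len = is_context
                        ? seq_len
                        : *std::max_element(step_list.begin(), step_list.end());
  for (int b = 0; b < batch; ++b) {
    for (int q = 0; q < q_len; ++q) {
      // Query position: its own index in context, or step - 1 in decode.
      const int q_pos = is_context ? q : step_list[b] - 1;
      for (int h = 0; h < num_heads; ++h) {
        // Standard ALiBi slope over the *original* head count, offset by
        // rank * num_heads under tensor parallelism (assumed from the
        // GetRank() argument in the wrapper above).
        const int global_h = rank * num_heads + h;
        const double slope =
            std::pow(2.0, -8.0 * (global_h + 1) / ori_num_heads);
        for (int k = 0; k < k_len; ++k) {
          const std::size_t idx =
              ((static_cast<std::size_t>(b) * q_len + q) * num_heads + h) *
                  k_len +
              k;
          // Linear distance penalty for visible keys, zero for padding.
          out[idx] = (k <= q_pos) ? static_cast<T>(-slope * (q_pos - k))
                                  : static_cast<T>(0);
        }
      }
    }
  }
}
```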
@@ -54,32 +55,56 @@ AsStatus ALiBiPEOp::Init(const OperatorProto& op_proto, |
54 | 55 | return AsStatus::ALLSPARK_SUCCESS; |
55 | 56 | } |
56 | 57 |
57 | | -AsStatus ALiBiPEOp::Reshape() { |
| 58 | +AsStatus ALiBiPEOp::Reshape(RuntimeContext* runtime_ctx) { |
58 | 59 | Shape in_shape = tensor_map_->at(in_names_[0])->GetShape(); |
59 | | - batch_size_ = in_shape[0]; |
60 | | - if (gen_ctx_->step == 0) { |
| 60 | + if (runtime_ctx->is_context == true) { |
| 61 | + batch_size_ = in_shape[0]; |
61 | 62 | seq_length_ = in_shape[1]; |
62 | 63 | Shape out_shape = Shape{batch_size_, seq_length_, num_heads_, seq_length_}; |
63 | 64 | AS_CHECK_STATUS( |
64 | 65 | tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape))); |
65 | | - } else { |
66 | | - seq_length_ = 1; |
67 | 66 | } |
| 67 | + return AsStatus::ALLSPARK_SUCCESS; |
| 68 | +} |
68 | 69 |
| 70 | +AsStatus ALiBiPEOp::runContext(RuntimeContext* runtime_ctx) { |
| 71 | + int* batch_offset = nullptr; |
| 72 | + AsTensor* out_tensor = tensor_map_->at(out_names_[0]).get(); |
| 73 | + std::vector<int> step_list; |
| 74 | + kernel_launcher(out_tensor->GetDataType(), out_tensor->GetDataPtr(), |
| 75 | + batch_offset, batch_size_, seq_length_, num_heads_, |
| 76 | + ori_num_heads_, ctx_, runtime_ctx->is_context, step_list); |
69 | 77 | return AsStatus::ALLSPARK_SUCCESS; |
70 | 78 | } |
71 | 79 |
72 | | -AsStatus ALiBiPEOp::Forward() { |
| 80 | +AsStatus ALiBiPEOp::runDecode(RuntimeContext* runtime_ctx) { |
73 | 81 | int* batch_offset = nullptr; |
74 | 82 | AsTensor* out_tensor = tensor_map_->at(out_names_[0]).get(); |
75 | | - if (gen_ctx_->step != 0) { |
76 | | - Shape out_shape = Shape{batch_size_, 1, num_heads_, gen_ctx_->step + 1}; |
77 | | - AS_CHECK_STATUS( |
78 | | - tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape))); |
| 83 | + int batch_size = runtime_ctx->GetGenCtxListSize(); |
| 84 | + std::vector<int> step_list(batch_size); |
| 85 | + int max_step = 1; |
| 86 | + for (int i = 0; i < batch_size; i++) { |
| 87 | + GenerateContext* gen_ctx = runtime_ctx->GetGenCtx(i); |
| 88 | + if (gen_ctx->step + 1 > max_step) { |
| 89 | + max_step = gen_ctx->step + 1; |
| 90 | + } |
| 91 | + step_list[i] = gen_ctx->step + 1; |
79 | 92 | } |
| 93 | + Shape out_shape = Shape{batch_size, 1, num_heads_, max_step}; |
| 94 | + AS_CHECK_STATUS( |
| 95 | + tensor_map_->at(out_names_[0])->SetShape(std::move(out_shape))); |
80 | 96 | kernel_launcher(out_tensor->GetDataType(), out_tensor->GetDataPtr(), |
81 | | - batch_offset, batch_size_, seq_length_, num_heads_, |
82 | | - ori_num_heads_, (gen_ctx_->step + 1), ctx_); |
| 97 | + batch_offset, batch_size, max_step, num_heads_, |
| 98 | + ori_num_heads_, ctx_, runtime_ctx->is_context, step_list); |
| 99 | + return AsStatus::ALLSPARK_SUCCESS; |
| 100 | +} |
| 101 | + |
| 102 | +AsStatus ALiBiPEOp::Forward(RuntimeContext* runtime_ctx) { |
| 103 | + if (runtime_ctx->is_context == true) { |
| 104 | + runContext(runtime_ctx); |
| 105 | + } else { |
| 106 | + runDecode(runtime_ctx); |
| 107 | + } |
83 | 108 | return AsStatus::ALLSPARK_SUCCESS; |
84 | 109 | } |
85 | 110 |
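To make the decode-path shape handling concrete: each in-flight request contributes `gen_ctx->step + 1` to `step_list`, and the output's last dimension is padded out to the longest of those values. A tiny standalone check of that bookkeeping, with made-up step values (4, 9, 6 are hypothetical, chosen only for illustration):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Hypothetical batch: three in-flight requests at steps 4, 9 and 6.
  std::vector<int> steps = {4, 9, 6};

  // Mirror of the loop in runDecode(): each entry is gen_ctx->step + 1,
  // and the last output dimension is the maximum of those values.
  std::vector<int> step_list(steps.size());
  int max_step = 1;
  for (std::size_t i = 0; i < steps.size(); ++i) {
    step_list[i] = steps[i] + 1;
    max_step = std::max(max_step, step_list[i]);
  }

  assert((step_list == std::vector<int>{5, 10, 7}));
  assert(max_step == 10);
  // The output tensor is then shaped {batch, 1, num_heads_, max_step},
  // so shorter requests carry padding beyond their own step.
  return 0;
}
```

The context path needs no such padding, which is why `Reshape()` now only sets the `{batch, seq_len, num_heads, seq_len}` shape when `is_context` is true and leaves the decode shape to `runDecode()`, where the per-request steps are known.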