
Commit 012eb1b

update sampling, prefix cache, json mode impl (#55)
- engine: stop and release models when the engine is released; remove the deprecated lock
- sampling: heavily rework generate_op; remove the dependency on global tensors
- prefix cache: fix several bugs; improve eviction performance
- json mode: update the lmfe-cpp patch; add process_logits; sample with top_k / top_p
- span-attention: move span_attn decoderReshape to init
- lora: add docs; fix a typo
- ubuntu: add an Ubuntu dockerfile; fix an install-dir error
- bugfix: fix a multi-batch repetition-penalty bug
1 parent 0d65c04 commit 012eb1b

File tree

80 files changed (+1840, −2401 lines)

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -25,5 +25,4 @@ third_party/from_source/openssl/*
 *.nsys-rep
 log*
 *.csv
-#*.sh
 *.as*

cmake/hie-dnn.cmake

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ message(STATUS "Build HIE-DNN with: USE_CUDA=${HIEDNN_USE_CUDA}")
 message(STATUS "Build HIE-DNN with: CUDA_DEVICE_ARCH=${HIEDNN_CUDA_DEVICE_ARCH}")

 set(HIEDNN_INSTALL ${INSTALL_LOCATION}/HIE-DNN/install)
-set(HIEDNN_LIBRARY_PATH ${HIEDNN_INSTALL}/lib64/libhiednn_static.a)
+set(HIEDNN_LIBRARY_PATH ${HIEDNN_INSTALL}/${CMAKE_INSTALL_LIBDIR}/libhiednn_static.a)
 message(STATUS "HIEDNN_INSTALL: ${HIEDNN_INSTALL}")
 message(STATUS "HIEDNN_LIBRARY_PATH: ${HIEDNN_LIBRARY_PATH}")

csrc/common/as_engine.cpp

Lines changed: 33 additions & 18 deletions
@@ -246,7 +246,6 @@ class AsEngineImpl final {

 ModelControlState::ModelControlState(const std::string& name)
     : model_name(name), msg_queue(1000) {
-  cond_var = std::make_unique<std::condition_variable>();
   request_handle_map.reserve(1000);
   result_queue_map.reserve(1000);
   msg_queue_size.store(0);
@@ -328,14 +327,25 @@ AsEngineImpl::~AsEngineImpl() {
   // we should wait for the running thread stop otherwise it will cause
   // exception.
   LOG(INFO) << "~AsEngine called";
-  for (auto& model_state : model_state_map_) {
-    if (!model_state.second && model_state.second->model_stopped) {
-      LOG(INFO) << "Stopping model " << model_state.first;
-      StopModel(model_state.first.c_str());
-      ReleaseModel(model_state.first.c_str());
-      // model_state.second->StopLoop();
+
+  std::vector<std::string> pending_stop_model;
+
+  {
+    std::lock_guard<std::mutex> guard(engine_lock_);
+    LOG(INFO) << "model_state_map_ size:" << model_state_map_.size();
+    for (auto& model_state : model_state_map_) {
+      if (!model_state.second->model_stopped) {
+        LOG(INFO) << "Stopping model " << model_state.first;
+        pending_stop_model.push_back(model_state.first);
+      }
     }
   }
+
+  for (auto& name : pending_stop_model) {
+    StopModel(name.c_str());
+    ReleaseModel(name.c_str());
+  }
+
   // LOG(INFO) << "~Engine clear BFC Allocator ";
   // free weight manager before destroy bfc.
   bool do_destroy_bfc = weight_manager_->GetNumModels() > 0;
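Note on the destructor hunk above: the old loop issued StopModel/ReleaseModel while iterating model_state_map_, and its condition (!model_state.second && model_state.second->model_stopped) dereferenced a null pointer whenever its first half was true. The new code is a two-phase shutdown: snapshot the pending model names under engine_lock_, then issue the blocking stop calls with the lock released. A minimal standalone sketch of the pattern (class and member names here are illustrative, not the engine's):

#include <map>
#include <mutex>
#include <string>
#include <vector>

class Engine {
 public:
  void Shutdown() {
    std::vector<std::string> pending;
    {
      std::lock_guard<std::mutex> guard(lock_);
      for (auto& [name, stopped] : models_) {
        if (!stopped) pending.push_back(name);
      }
    }  // lock released here, before the blocking stop calls
    for (auto& name : pending) StopOne(name);
  }

 private:
  void StopOne(const std::string& name) {
    // In the engine this waits on the model's loop thread and may itself
    // need the engine lock; calling it with the lock held would deadlock.
    std::lock_guard<std::mutex> guard(lock_);
    models_[name] = true;
  }

  std::mutex lock_;
  std::map<std::string, bool> models_{{"model_a", false}};  // name -> stopped
};

int main() {
  Engine e;
  e.Shutdown();
}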
@@ -1557,9 +1567,10 @@ AsStatus AsEngineImpl::StopModel(const char* model_name) {
       LOG(ERROR) << "push message queue failed.";
     }
   }
-  model_state->cond_var->notify_all();

+  LOG(INFO) << "AsEngineImpl:: wait stop model return";
   auto ret = reply_promise->get_future().get();
+  LOG(INFO) << "AsEngineImpl:: stop model got return.";

   if (ret != AsStatus::ALLSPARK_SUCCESS) {
     LOG(ERROR) << "[" << model_name << "] "
@@ -1781,8 +1792,7 @@ AsStatus AsEngineImpl::StartRequest(
   handle->mm_embedding_internal = extra_embedding;

 #ifdef ENABLE_JSON_MODE
-  if (request_info->config.response_format.find("type") !=
-          request_info->config.response_format.end() &&
+  if (request_info->config.response_format.count("type") &&
       request_info->config.response_format["type"] == "json_object") {
     if (util::FormatEnforcer::vocab_.empty() &&
         request_info->config.vocab.empty()) {
@@ -1824,7 +1834,6 @@ AsStatus AsEngineImpl::StartRequest(
     // create result queue & handle
   }

-  model_state->cond_var->notify_one();
 #ifndef ENABLE_CUDA
   workers_[0]->GetDeviceContext()->SemWaitSendInterProcess();
 #endif
@@ -1869,7 +1878,6 @@ AsStatus AsEngineImpl::StopRequest(const char* model_name,
                                    reply_promise, uuid);
     model_state->msg_queue.enqueue(std::move(msg));
   }
-  model_state->cond_var->notify_one();
 #ifndef ENABLE_CUDA
   workers_[0]->GetDeviceContext()->SemWaitSendInterProcess();
 #endif
@@ -1916,7 +1924,6 @@ AsStatus AsEngineImpl::ReleaseRequest(const char* model_name,
     model_state->msg_queue.enqueue(std::move(msg));
   }

-  model_state->cond_var->notify_one();
   auto ret = reply_promise->get_future().get();
   if (ret == AsStatus::ALLSPARK_SUCCESS) {
     LOG(INFO) << "[" << model_name << "] "
@@ -1965,7 +1972,6 @@ AsStatus AsEngineImpl::SyncRequest(const char* model_name,
                                    reply_promise, uuid);
     model_state->msg_queue.enqueue(std::move(msg));
   }
-  model_state->cond_var->notify_one();
 #ifndef ENABLE_CUDA
   workers_[0]->GetDeviceContext()->SemWaitSendInterProcess();
 #endif
@@ -2430,6 +2436,7 @@ AsStatus AsEngineImpl::InputParamsVerify(
                << "gen_cfg.top_p must in [0,1]" << std::endl;
     return AsStatus::ALLSPARK_PARAM_ERROR;
   }
+
   if (gen_cfg.temperature < SAMPLING_EPS) {
     DLOG(INFO) << "[" << model_name << "] "
                << "gen_cfg.temperature = " << gen_cfg.temperature
@@ -2438,6 +2445,14 @@ AsStatus AsEngineImpl::InputParamsVerify(
     gen_cfg.top_p = 0;
     gen_cfg.temperature = 1.0;
   }
+
+  if (std::abs(gen_cfg.top_p - 1.0) < 1e-6) {
+    LOG(WARNING) << "[" << model_name << "] "
+                 << "gen_cfg.top_p == 1.0, This might lead to performance "
+                    "issues, so it is manually set to 0.99. "
+                 << std::endl;
+    gen_cfg.top_p = 0.99;
+  }
   // user customized max batch size
   if (engine_max_batch_ != 0 && input_batch > engine_max_batch_) {
     LOG(ERROR) << "[" << model_name << "] "
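Why clamp top_p at 1.0: with top_p == 1.0 the nucleus never closes, so the sampler has to keep (and sort and normalize) the entire vocabulary every step; 0.99 cuts the long tail at negligible quality cost. A toy illustration of the candidate-count difference (the probabilities are made up; the real kernel works on GPU logits):

#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

// Number of candidates nucleus sampling keeps for a given top_p:
// the smallest prefix of the sorted distribution with mass >= top_p.
static size_t NucleusSize(std::vector<float> probs, float top_p) {
  std::sort(probs.begin(), probs.end(), std::greater<float>());
  float cum = 0.f;
  for (size_t i = 0; i < probs.size(); ++i) {
    cum += probs[i];
    if (cum >= top_p) return i + 1;
  }
  return probs.size();  // top_p == 1.0: the whole vocabulary survives
}

int main() {
  std::vector<float> probs = {0.9f, 0.09f, 0.009f, 0.0009f, 0.0001f};
  std::printf("top_p=0.99 keeps %zu tokens\n", NucleusSize(probs, 0.99f));  // 2
  std::printf("top_p=1.00 keeps %zu tokens\n", NucleusSize(probs, 1.00f));  // 5
}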
@@ -2481,8 +2496,8 @@ AsStatus AsEngineImpl::StartRequestImpl(
              << std::endl;

   TensorMap out_tensors;
-  // TODO: alloc generated_ids on CPU
-  std::string out_name = "generated_ids";
+  // TODO: alloc generated_ids_global on CPU
+  std::string out_name = "generated_ids_global";
   out_tensors.insert(
       {out_name, std::make_shared<AsTensor>(out_name, DeviceType::CPU,
                                             DataType::INT64, DataMode::DENSE,
@@ -2531,11 +2546,11 @@ FetchGenerationResultAndIncreaseCounter(
   ele->prefix_len_gpu = request->prefix_len_gpu;
   ele->prefix_len_cpu = request->prefix_len - request->prefix_len_gpu;

-  TensorMap& tmap = request->outputs;
+  const TensorMap& tmap = request->outputs;

   std::vector<std::vector<std::pair<int, float>>> log_probs_list =
       request->log_probs_list;
-  auto device_tensor_ptr = tmap.at("generated_ids");
+  auto device_tensor_ptr = tmap.at("generated_ids_global");
   if (device_tensor_ptr->GetShape().Count() == 0) {
     return nullptr;
   }

csrc/common/common.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,24 @@ inline std::ostream& operator<<(std::ostream& os, DeviceType device_type) {
   }
 }

+inline int get_layer_num(std::string str) {
+  std::stringstream ss(str);
+  std::string temp;
+  while (std::getline(ss, temp, '.')) {
+    bool flag = true;
+    for (char c : temp) {
+      if (!std::isdigit(c)) /* not a digit: this segment is not a number */ {
+        flag = false;
+        break;
+      }
+    }
+    if (flag) {
+      return std::stoi(temp);
+    }
+  }
+  return -1;
+}
+
 // deprecated api declear
 #if __cplusplus >= 201402L  // c++14
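The new helper walks a dot-separated tensor or weight name and returns the first all-digit segment, i.e. the layer index. A hedged usage sketch (the weight names are hypothetical; this copy also guards against empty segments, which the version above would pass straight to std::stoi):

#include <cctype>
#include <cstdio>
#include <sstream>
#include <string>

// Same logic as the helper added to csrc/common/common.h: return the
// first dot-separated segment made entirely of digits, or -1 if none.
inline int get_layer_num(std::string str) {
  std::stringstream ss(str);
  std::string temp;
  while (std::getline(ss, temp, '.')) {
    bool all_digits = !temp.empty();  // extra guard for "a..b"-style names
    for (char c : temp) {
      if (!std::isdigit(static_cast<unsigned char>(c))) {
        all_digits = false;
        break;
      }
    }
    if (all_digits) return std::stoi(temp);
  }
  return -1;
}

int main() {
  // Hypothetical names; the exact naming scheme is model-specific.
  std::printf("%d\n", get_layer_num("decoder.layer.17.attention.weight"));  // 17
  std::printf("%d\n", get_layer_num("embedding.word_embeddings"));          // -1
}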

csrc/common/engine_runtime.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,10 @@ class ModelControlState final {
   moodycamel::BlockingConcurrentQueue<EngineControlMessage> msg_queue;
   std::atomic<int> msg_queue_size;

-  std::unique_ptr<std::condition_variable> cond_var;
-
   std::unordered_map<std::string, std::shared_ptr<RequestHandle>>
       request_handle_map;
   std::unordered_map<std::string, std::shared_ptr<AsEngine::ResultQueue>>
       result_queue_map;
-  std::queue<std::shared_ptr<RequestHandle>> release_request_handle;
-  std::queue<std::shared_ptr<AsEngine::ResultQueue>> release_request_queue;
   std::atomic<bool> model_stopping =
       false;  // after GracefulStopModel called...
   std::atomic<bool> model_stopped = false;  // after GracefulStopModel is done.
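With cond_var gone, the model loop can block directly on msg_queue: moodycamel::BlockingConcurrentQueue's wait_dequeue sleeps until a message arrives, so producers such as StartRequest and StopModel no longer need a separate notify_one/notify_all step (the calls deleted in as_engine.cpp above). A minimal sketch, assuming the moodycamel header is on the include path as it already is for this repo; the Message type here is simplified:

#include <string>
#include <thread>

#include <blockingconcurrentqueue.h>  // moodycamel, an existing dependency

struct Message { std::string kind; };

int main() {
  moodycamel::BlockingConcurrentQueue<Message> queue;

  // Consumer: wait_dequeue blocks until an item is available, which is
  // what makes the explicit condition_variable redundant.
  std::thread consumer([&] {
    Message msg;
    for (;;) {
      queue.wait_dequeue(msg);        // sleeps until a message is enqueued
      if (msg.kind == "stop") break;  // shutdown message, as in StopModel
    }
  });

  queue.enqueue(Message{"start_request"});  // no notify call needed
  queue.enqueue(Message{"stop"});
  consumer.join();
}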

csrc/common/extra_embedding.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,17 @@ class ExtraEmbeddingUtils {
   }

   static AsStatus CreateTensorForHash(std::shared_ptr<Request> request,
-                                      TensorMap& tensor_map,
+                                      TensorMap& dst_tensor_map,
+                                      const TensorMap& src_tensor_map,
                                       std::string src_tensor_name) {
     std::string dst_tensor_name = src_tensor_name + "_for_hash";

     if (!request->extra_embedding.empty()) {
       if (request->extra_embedding.count("hash_input") > 0) {
         // step 1: parse extra_embedding info
         int64_t* tensor_ptr =
-            (int64_t*)tensor_map[src_tensor_name]->GetDataPtr();
-        int seq_len = tensor_map[src_tensor_name]->GetShape()[1];
+            (int64_t*)src_tensor_map.at(src_tensor_name)->GetDataPtr();
+        int seq_len = src_tensor_map.at(src_tensor_name)->GetShape()[1];
         auto reinfo_vec = std::make_shared<ExtraEmbeddingUtils::REInfoList>();
         AS_CHECK_STATUS(ExtraEmbeddingUtils::ParseExtraEmbedding(
             request->extra_embedding, tensor_ptr, seq_len, reinfo_vec));
@@ -165,20 +166,22 @@ class ExtraEmbeddingUtils {

         // step 3: create a new input tensor
         auto dst_tensor = std::make_shared<AsTensor>(
-            dst_tensor_name, *tensor_map[src_tensor_name]);
+            dst_tensor_name, *src_tensor_map.at(src_tensor_name));

         // step 4: replace place holder with hashes
         ExtraEmbeddingUtils::ReplacePlaceHolder(dst_tensor, reinfo_vec);

-        tensor_map.insert({dst_tensor_name, dst_tensor});
+        dst_tensor_map.insert({dst_tensor_name, dst_tensor});
       } else {
         LOG(ERROR) << "multi-media content `hash_input` "
                    << "of request " << request->request_id << " is missing.";
         return AsStatus::ALLSPARK_PARAM_ERROR;
       }
     } else {
-      // no extra embedding, use original input_ids for hash
-      tensor_map.insert({dst_tensor_name, tensor_map[src_tensor_name]});
+      // no extra embedding, copy original tensor for hash
+      auto dst_tensor = std::make_shared<AsTensor>(
+          dst_tensor_name, *src_tensor_map.at(src_tensor_name));
+      dst_tensor_map.insert({dst_tensor_name, dst_tensor});
     }

     return AsStatus::ALLSPARK_SUCCESS;
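Besides letting the source map be taken by const reference, the switch from tensor_map[...] to src_tensor_map.at(...) removes a quiet footgun: on standard map types, operator[] default-constructs and inserts a missing key, while at() throws. A self-contained illustration with a stand-in map type (the real TensorMap maps names to AsTensor pointers):

#include <cstdio>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Stand-in for TensorMap, just to show the container semantics.
using FakeTensorMap = std::map<std::string, std::shared_ptr<int>>;

int main() {
  FakeTensorMap tensors;
  tensors["input_ids"] = std::make_shared<int>(42);

  // operator[] on a misspelled key quietly inserts a null entry:
  auto bad = tensors["input_idz"];  // no error; the map now has 2 keys
  std::printf("size after operator[]: %zu\n", tensors.size());  // 2

  // .at() fails fast instead, and also works on const maps -- which is
  // what lets CreateTensorForHash take src_tensor_map by const reference.
  try {
    const FakeTensorMap& src = tensors;
    auto good = src.at("input_ids");       // ok
    auto missing = src.at("nonexistent");  // throws std::out_of_range
    (void)good; (void)missing;
  } catch (const std::out_of_range&) {
    std::printf("at() threw on a missing key\n");
  }
  (void)bad;
}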

csrc/common/generate_context.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
 #pragma once

 #if ENABLE_SPAN_ATTENTION
+#include <cache/prefix_cache_manager.h>
 #include <cache/virtual_cache.h>
 #endif
 #include <common/common.h>
@@ -59,7 +60,7 @@ struct GenerateContext {
 #if ENABLE_SPAN_ATTENTION
   std::unique_ptr<VirtualCache> virtual_k_cache;
   std::unique_ptr<VirtualCache> virtual_v_cache;
-  std::vector<std::string> prefix_cache_hash_list;
+  std::vector<PrefixCacheManager::PrefixNodePtr> prefix_cache_node_list;
 #endif

 #ifdef ENABLE_JSON_MODE
@@ -70,7 +71,7 @@ struct GenerateContext {
   std::unique_ptr<AsTensor> sample_state = nullptr;
 };

-using GenContextList = std::vector<std::unique_ptr<GenerateContext>>;
+using GenContextList = std::vector<std::shared_ptr<GenerateContext>>;
 class LayerCacheManager {
  public:
   AsTensor* GetCache(std::string cache_name) {
@@ -102,15 +103,15 @@ class RuntimeContext {
   std::vector<float> logprobs_value_host;
   std::vector<float> token_logprobs_host;

-  GenerateContext* GetContextGenCtx() const {
-    return gen_ctx_list[current_batch].get();
+  std::shared_ptr<GenerateContext> GetContextGenCtx() const {
+    return gen_ctx_list[current_batch];
   }
-  GenerateContext* GetGenCtx(int index) const {
-    return gen_ctx_list[index].get();
+  std::shared_ptr<GenerateContext> GetGenCtx(int index) const {
+    return gen_ctx_list[index];
   }
   int GetGenCtxListSize() const { return gen_ctx_list.size(); }
-  void PushBackGenCtx(std::unique_ptr<GenerateContext> gen_ctx) {
-    gen_ctx_list.push_back(std::move(gen_ctx));
+  void PushBackGenCtx(std::shared_ptr<GenerateContext> gen_ctx) {
+    gen_ctx_list.push_back(gen_ctx);
     gen_ctx_list[gen_ctx_list.size() - 1]->current_batch =
         gen_ctx_list.size() - 1;
   }
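One plausible motivation for moving GenContextList from unique_ptr to shared_ptr, consistent with the prefix_cache_node_list change above (this reading is a guess, not stated in the commit): callers that hold a context, such as sampling ops or prefix-cache bookkeeping, stay safe even after the context leaves the list, where the old raw-pointer accessors would dangle. A minimal sketch:

#include <cstdio>
#include <memory>
#include <vector>

struct Ctx { int current_batch = 0; };

int main() {
  // The old API returned raw pointers out of a vector<unique_ptr<Ctx>>;
  // once the list shrank (finished requests are popped), those pointers
  // dangled. Returning shared_ptr, as GetGenCtx now does, keeps a context
  // alive for whoever still holds it after it is removed from the list.
  std::vector<std::shared_ptr<Ctx>> gen_ctx_list;
  gen_ctx_list.push_back(std::make_shared<Ctx>());

  std::shared_ptr<Ctx> held = gen_ctx_list[0];  // e.g. held by a sampling op
  gen_ctx_list.pop_back();                      // request finished, removed

  held->current_batch = 3;  // still valid: the refcount is nonzero
  std::printf("use_count=%ld\n", static_cast<long>(held.use_count()));  // 1
}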
@@ -124,16 +125,15 @@ class RuntimeContext {
     gen_ctx_list[index]->current_batch = index;
     gen_ctx_list.pop_back();
   }
-  std::shared_ptr<LayerCacheManager> CreateLayerCacheManager() {
+  void CreateLayerCacheManager() {
     layer_cache_manager = std::make_shared<LayerCacheManager>();
-    return layer_cache_manager;
   }
   std::shared_ptr<LayerCacheManager> GetLayerCacheManager() {
     return layer_cache_manager;
   }

  private:
-  GenContextList gen_ctx_list = std::vector<std::unique_ptr<GenerateContext>>();
+  GenContextList gen_ctx_list = std::vector<std::shared_ptr<GenerateContext>>();
   std::shared_ptr<LayerCacheManager> layer_cache_manager;
 };
csrc/common/memory_reuser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ string noreused[] = {"cross_attention.key_value.out",
                      "batch_offset",
                      "transmask.out",
                      "max_dec_ids",
-                     "generated_ids",
+                     "generated_ids_global",
                      "dec_ids",
                      "next_beam_id",
                      "hyps_ids",

csrc/common/request.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ class FormatEnforcer;
 struct Request {
   std::string request_id;
   TensorMap inputs;
-  TensorMap outputs;
+  const TensorMap outputs;  // created in AsEngineImpl, shared by all workers
+  TensorMap interim;        // intermediate tensors
   GenerateConfig gen_cfg;
   std::vector<std::vector<std::pair<int, float>>> log_probs_list;
   std::vector<float> token_logprobs_list;
@@ -42,31 +43,33 @@ struct Request {
   const std::chrono::time_point<std::chrono::steady_clock> start_ts;
   std::chrono::time_point<std::chrono::steady_clock> context_ts;
   std::chrono::time_point<std::chrono::steady_clock> generate_ts;
+
   Request(const std::string& request_id_, const TensorMap& inputs_,
-          const TensorMap& outputs_, const GenerateConfig& gen_cfg)
+          const TensorMap& outputs_, const GenerateConfig& gen_cfg,
+          const TensorMap& interim_ = {})
       : request_id(request_id_),
        inputs(inputs_),
        outputs(outputs_),
+       interim(interim_),
        gen_cfg(gen_cfg),
        finish(false),
        status(AsEngine::GenerateRequestStatus::Init),
        start_ts(std::chrono::steady_clock::now()) {}
-  Request(std::shared_ptr<Request> source_request) {
-    if (source_request) {
-      this->request_id = source_request->request_id;
-      this->inputs = source_request->inputs;
-      this->outputs = source_request->outputs;
-      this->gen_cfg = source_request->gen_cfg;
-      this->log_probs_list = source_request->log_probs_list;
-      this->token_logprobs_list = source_request->token_logprobs_list;
-      this->finish = source_request->finish;
-      this->input_len = source_request->input_len;
-      this->prefill_chunk_len = source_request->prefill_chunk_len;
-      this->prefix_len = source_request->prefix_len;
-      this->status = source_request->status;
-      this->extra_embedding = source_request->extra_embedding;
-    }
-  }
+
+  Request(std::shared_ptr<Request> source_request)
+      : request_id(source_request->request_id),
+        inputs(source_request->inputs),
+        outputs(source_request->outputs),
+        interim(source_request->interim),
+        gen_cfg(source_request->gen_cfg),
+        log_probs_list(source_request->log_probs_list),
+        token_logprobs_list(source_request->token_logprobs_list),
+        finish(source_request->finish),
+        input_len(source_request->input_len),
+        prefill_chunk_len(source_request->prefill_chunk_len),
+        prefix_len(source_request->prefix_len),
+        status(source_request->status),
+        extra_embedding(source_request->extra_embedding) {}
 };

 } // namespace allspark
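A side effect of declaring outputs as const TensorMap: the copy-from-shared_ptr constructor can no longer assign members in its body and had to become a member initializer list, as rewritten above (a const member can only be initialized, never assigned). A minimal sketch of why, with stand-in types rather than the real Request:

#include <map>
#include <string>

using TensorMapLike = std::map<std::string, int>;  // stand-in for TensorMap

struct Req {
  const TensorMapLike outputs;  // const member: assignment is ill-formed

  // `outputs` must be initialized in the member initializer list;
  // writing `outputs = src;` in the constructor body would not compile.
  explicit Req(const TensorMapLike& src) : outputs(src) {}
};

int main() {
  Req r({{"generated_ids_global", 1}});
  return r.outputs.count("generated_ids_global") ? 0 : 1;
}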
