Skip to content

Commit 31344ca

Browse files
author
youxiao
committed
adapt to transfer async
1 parent 2ea9115 commit 31344ca

File tree

3 files changed

+123
-22
lines changed

3 files changed

+123
-22
lines changed

mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class AscendDirectTransport : public Transport {
7979

8080
void workerThread();
8181

82+
void queryThread();
83+
8284
void processSliceList(const std::vector<Slice *> &slice_list);
8385

8486
void localCopy(TransferRequest::OpCode opcode,
@@ -116,16 +118,21 @@ class AscendDirectTransport : public Transport {
116118
std::set<std::string> connected_segments_;
117119
std::mutex connection_mutex_;
118120

119-
// Async processing related members (similar to hccl_transport)
121+
// Async processing related members
120122
std::thread worker_thread_;
121123
std::queue<std::vector<Slice *>> slice_queue_;
122124
std::mutex queue_mutex_;
123125
std::condition_variable queue_cv_;
124126

127+
std::thread query_thread_;
128+
std::queue<std::vector<Slice *>> query_slice_queue_;
129+
std::mutex query_mutex_;
130+
std::condition_variable query_cv_;
131+
125132
int32_t device_logic_id_{};
126133
aclrtContext rt_context_{nullptr};
127134
int32_t connect_timeout_ = 10000;
128-
int32_t transfer_timeout_ = 10000;
135+
int64_t transfer_timeout_ = 10000;
129136
std::string local_adxl_engine_name_{};
130137
aclrtStream stream_{};
131138
bool use_buffer_pool_{false};

mooncake-transfer-engine/include/transport/transport.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class Transport {
142142
} hccl;
143143
// Per-slice state used by the Ascend direct transport.
struct {
144144
uint64_t dest_addr;
145+
// Handle of the asynchronous transfer request this slice belongs to.
// Written by processSliceList() after TransferAsync succeeds and cast back
// to adxl::TransferReq by queryThread() when polling the transfer status.
void *handle;
146+
// getCurrentTimeInNano() value captured just before the transfer is
// submitted; queryThread() uses it for timeout detection and for the
// transfer-duration log line.
int64_t start_time;
145147
} ascend_direct;
146148
};
147149

mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,14 @@ AscendDirectTransport::~AscendDirectTransport() {
4545
// Stop worker and query threads
4646
running_ = false;
4747
queue_cv_.notify_all();
48+
query_cv_.notify_all();
4849

4950
if (worker_thread_.joinable()) {
5051
worker_thread_.join();
5152
}
53+
if (query_thread_.joinable()) {
54+
query_thread_.join();
55+
}
5256

5357
// Disconnect all connections
5458
std::lock_guard<std::mutex> lock(connection_mutex_);
@@ -123,6 +127,7 @@ int AscendDirectTransport::install(std::string &local_server_name,
123127
// Start worker and query threads
124128
running_ = true;
125129
worker_thread_ = std::thread(&AscendDirectTransport::workerThread, this);
130+
query_thread_ = std::thread(&AscendDirectTransport::queryThread, this);
126131
return 0;
127132
}
128133

@@ -156,8 +161,8 @@ int AscendDirectTransport::InitAdxlEngine() {
156161
}
157162
}
158163
// set default buffer pool
159-
options["adxl.BufferPool"] = "4:8";
160-
use_buffer_pool_ = true;
164+
options["adxl.BufferPool"] = "0:0";
165+
use_buffer_pool_ = false;
161166
char *buffer_pool = std::getenv("ASCEND_BUFFER_POOL");
162167
if (buffer_pool) {
163168
options["adxl.BufferPool"] = buffer_pool;
@@ -192,9 +197,11 @@ int AscendDirectTransport::InitAdxlEngine() {
192197
parseFromString<int32_t>(connect_transfer_str);
193198
if (transfer_timeout.has_value()) {
194199
transfer_timeout_ = transfer_timeout.value();
195-
LOG(INFO) << "Set transfer timeout to:" << transfer_timeout_;
200+
LOG(INFO) << "Set transfer timeout to:" << transfer_timeout_
201+
<< " us.";
196202
}
197203
}
204+
transfer_timeout_ = transfer_timeout_ * 1000000;
198205
return 0;
199206
}
200207

@@ -542,6 +549,93 @@ void AscendDirectTransport::workerThread() {
542549
LOG(INFO) << "AscendDirectTransport worker thread stopped";
543550
}
544551

552+
void AscendDirectTransport::queryThread() {
553+
LOG(INFO) << "AscendDirectTransport query thread started";
554+
std::vector<std::vector<Slice *>> pending_batches;
555+
while (running_) {
556+
{
557+
std::unique_lock<std::mutex> lock(query_mutex_);
558+
if (pending_batches.empty()) {
559+
query_cv_.wait(lock, [this] {
560+
return !running_ || !query_slice_queue_.empty();
561+
});
562+
}
563+
if (!running_) {
564+
break;
565+
}
566+
while (!query_slice_queue_.empty()) {
567+
pending_batches.emplace_back(
568+
std::move(query_slice_queue_.front()));
569+
query_slice_queue_.pop();
570+
}
571+
}
572+
573+
if (pending_batches.empty()) {
574+
continue;
575+
}
576+
577+
auto it = pending_batches.begin();
578+
while (it != pending_batches.end()) {
579+
auto &slice_list = *it;
580+
if (slice_list.empty()) {
581+
it = pending_batches.erase(it);
582+
continue;
583+
}
584+
auto handle = static_cast<adxl::TransferReq>(
585+
slice_list[0]->ascend_direct.handle);
586+
adxl::TransferStatus task_status;
587+
auto ret = adxl_->GetTransferStatus(handle, task_status);
588+
if (ret != adxl::SUCCESS ||
589+
task_status == adxl::TransferStatus::FAILED) {
590+
LOG(ERROR) << "Get transfer status failed, ret: " << ret;
591+
for (auto &slice : slice_list) {
592+
slice->markFailed();
593+
}
594+
it = pending_batches.erase(it);
595+
} else if (task_status == adxl::TransferStatus::COMPLETED) {
596+
auto now = getCurrentTimeInNano();
597+
auto duration = now - slice_list[0]->ascend_direct.start_time;
598+
auto target_segment_desc =
599+
metadata_->getSegmentDescByID(slice_list[0]->target_id);
600+
if (target_segment_desc) {
601+
auto target_adxl_engine_name =
602+
(target_segment_desc->rank_info.hostIp + ":" +
603+
std::to_string(
604+
target_segment_desc->rank_info.hostPort));
605+
LOG(INFO) << "Transfer to " << target_adxl_engine_name
606+
<< " time: " << duration / 1000 << "us";
607+
}
608+
for (auto &slice : slice_list) {
609+
slice->markSuccess();
610+
}
611+
it = pending_batches.erase(it);
612+
} else {
613+
auto now = getCurrentTimeInNano();
614+
if (now - slice_list[0]->ascend_direct.start_time >
615+
transfer_timeout_) {
616+
LOG(ERROR)
617+
<< "Transfer timeout, you can increase the timeout "
618+
"duration to reduce "
619+
"the failure rate by configuring "
620+
"the ASCEND_TRANSFER_TIMEOUT environment variable.";
621+
for (auto &slice : slice_list) {
622+
slice->markFailed();
623+
}
624+
it = pending_batches.erase(it);
625+
} else {
626+
++it;
627+
}
628+
}
629+
}
630+
631+
if (!pending_batches.empty()) {
632+
// Avoid busy loop
633+
std::this_thread::sleep_for(std::chrono::microseconds(10));
634+
}
635+
}
636+
LOG(INFO) << "AscendDirectTransport query thread stopped";
637+
}
638+
545639
void AscendDirectTransport::processSliceList(
546640
const std::vector<Slice *> &slice_list) {
547641
if (slice_list.empty()) {
@@ -591,7 +685,6 @@ void AscendDirectTransport::processSliceList(
591685
}
592686
return;
593687
}
594-
auto start = std::chrono::steady_clock::now();
595688
std::vector<adxl::TransferOpDesc> op_descs;
596689
op_descs.reserve(slice_list.size());
597690
for (auto &slice : slice_list) {
@@ -602,26 +695,25 @@ void AscendDirectTransport::processSliceList(
602695
op_desc.len = slice->length;
603696
op_descs.emplace_back(op_desc);
604697
}
605-
auto status = adxl_->TransferSync(target_adxl_engine_name.c_str(),
606-
operation, op_descs, transfer_timeout_);
698+
auto start_time = getCurrentTimeInNano();
699+
for (auto &slice : slice_list) {
700+
slice->ascend_direct.start_time = start_time;
701+
}
702+
adxl::TransferReq req_handle;
703+
auto status =
704+
adxl_->TransferAsync(target_adxl_engine_name.c_str(), operation,
705+
op_descs, adxl::TransferArgs(), req_handle);
607706
if (status == adxl::SUCCESS) {
608707
for (auto &slice : slice_list) {
609-
slice->markSuccess();
708+
slice->ascend_direct.handle = req_handle;
610709
}
611-
LOG(INFO) << "Transfer to:" << target_adxl_engine_name << ", cost: "
612-
<< std::chrono::duration_cast<std::chrono::microseconds>(
613-
std::chrono::steady_clock::now() - start)
614-
.count()
615-
<< " us";
616-
} else {
617-
if (status == adxl::TIMEOUT) {
618-
LOG(ERROR) << "Transfer timeout to: " << target_adxl_engine_name
619-
<< ", you can increase the timeout duration to reduce "
620-
"the failure rate by configuring "
621-
"the ASCEND_TRANSFER_TIMEOUT environment variable.";
622-
} else {
623-
LOG(ERROR) << "Transfer slice failed with status: " << status;
710+
{
711+
std::unique_lock<std::mutex> lock(query_mutex_);
712+
query_slice_queue_.push(slice_list);
624713
}
714+
query_cv_.notify_one();
715+
} else {
716+
LOG(ERROR) << "Transfer slice failed with status: " << status;
625717
for (auto &slice : slice_list) {
626718
slice->markFailed();
627719
}

0 commit comments

Comments
 (0)