Skip to content

Commit f5780b4

Browse files
author
youxiao
committed
ascend enable fabric mem
1 parent 317384c commit f5780b4

File tree

5 files changed

+80
-20
lines changed

5 files changed

+80
-20
lines changed

mooncake-store/src/client_buffer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ std::shared_ptr<ClientBufferAllocator> ClientBufferAllocator::create(
2727
ClientBufferAllocator::ClientBufferAllocator(size_t size,
2828
const std::string& protocol)
2929
: protocol(protocol), buffer_size_(size) {
30+
if (size == 0) {
31+
return;
32+
}
3033
// Align to 64 bytes(cache line size) for better cache performance
3134
constexpr size_t alignment = 64;
3235
buffer_ = allocate_buffer_allocator_memory(size, protocol, alignment);
@@ -55,6 +58,9 @@ ClientBufferAllocator::~ClientBufferAllocator() {
5558
}
5659

5760
std::optional<BufferHandle> ClientBufferAllocator::allocate(size_t size) {
61+
if (allocator_ == nullptr) {
62+
return std::nullopt;
63+
}
5864
auto handle = allocator_->allocate(size);
5965
if (!handle) {
6066
return std::nullopt;

mooncake-store/src/client_service.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,13 @@ ErrorCode Client::InitTransferEngine(
261261
}
262262
}
263263

264+
if (protocol == "ascend") {
265+
const char* ascend_use_fabric_mem =
266+
std::getenv("ASCEND_ENABLE_USE_FABRIC_MEM");
267+
if (ascend_use_fabric_mem) {
268+
globalConfig().ascend_use_fabric_mem = true;
269+
}
270+
}
264271
auto [hostname, port] = parseHostNameWithPort(local_hostname);
265272
int rc = transfer_engine_->init(metadata_connstring, local_hostname,
266273
hostname, port);

mooncake-store/src/utils.cpp

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <random>
1111
#ifdef USE_ASCEND_DIRECT
1212
#include "acl/acl.h"
13+
#include "config.h"
1314
#endif
1415

1516
#include <ylt/coro_http/coro_http_client.hpp>
@@ -79,13 +80,41 @@ void *allocate_buffer_allocator_memory(size_t total_size,
7980
}
8081
#ifdef USE_ASCEND_DIRECT
8182
if (protocol == "ascend" && total_size > 0) {
82-
void *buffer = nullptr;
83-
auto ret = aclrtMallocHost(&buffer, total_size);
84-
if (ret != ACL_ERROR_NONE) {
85-
LOG(ERROR) << "Failed to allocate memory: " << ret;
86-
return nullptr;
83+
if (globalConfig().ascend_use_fabric_mem) {
84+
aclrtDrvMemHandle handle = nullptr;
85+
aclrtPhysicalMemProp prop = {};
86+
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
87+
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
88+
prop.memAttr = ACL_HBM_MEM_HUGE;
89+
prop.location.type = ACL_MEM_LOCATION_TYPE_HOST;
90+
prop.location.id = 0;
91+
prop.reserve = 0;
92+
auto ret = aclrtMallocPhysical(&handle, total_size, &prop, 0);
93+
if (ret != ACL_ERROR_NONE) {
94+
LOG(ERROR) << "Failed to allocate memory: " << ret;
95+
return nullptr;
96+
}
97+
void *va;
98+
ret = aclrtReserveMemAddress(&va, total_size, 0, nullptr, 1);
99+
if (ret != ACL_ERROR_NONE) {
100+
LOG(ERROR) << "Failed to reserve memory: " << ret;
101+
return nullptr;
102+
}
103+
ret = aclrtMapMem(va, total_size, 0, handle, 0);
104+
if (ret != ACL_ERROR_NONE) {
105+
LOG(ERROR) << "Failed to map memory: " << ret;
106+
return nullptr;
107+
}
108+
return va;
109+
} else {
110+
void *buffer = nullptr;
111+
auto ret = aclrtMallocHost(&buffer, total_size);
112+
if (ret != ACL_ERROR_NONE) {
113+
LOG(ERROR) << "Failed to allocate memory: " << ret;
114+
return nullptr;
115+
}
116+
return buffer;
87117
}
88-
return buffer;
89118
}
90119
#endif
91120
// Allocate aligned memory
@@ -95,7 +124,29 @@ void *allocate_buffer_allocator_memory(size_t total_size,
95124
void free_memory(const std::string &protocol, void *ptr) {
96125
#ifdef USE_ASCEND_DIRECT
97126
if (protocol == "ascend") {
98-
aclrtFreeHost(ptr);
127+
if (globalConfig().ascend_use_fabric_mem) {
128+
auto ret = aclrtUnmapMem(ptr);
129+
if (ret != ACL_ERROR_NONE) {
130+
LOG(ERROR) << "Failed to unmap memory: " << ptr;
131+
}
132+
aclrtReleaseMemAddress(ptr);
133+
if (ret != ACL_ERROR_NONE) {
134+
LOG(ERROR) << "Failed to release mem address: " << ptr;
135+
}
136+
aclrtDrvMemHandle handle;
137+
ret = aclrtMemRetainAllocationHandle(ptr, &handle);
138+
if (ret != ACL_ERROR_NONE) {
139+
LOG(ERROR) << "Failed to retain allocation handle: " << ptr;
140+
return;
141+
}
142+
ret = aclrtFreePhysical(handle);
143+
if (ret != ACL_ERROR_NONE) {
144+
LOG(ERROR) << "Failed to free physical handle: " << handle;
145+
return;
146+
}
147+
} else {
148+
aclrtFreeHost(ptr);
149+
}
99150
return;
100151
}
101152
#endif

mooncake-transfer-engine/include/config.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ struct GlobalConfig {
5858
bool enable_dest_device_affinity = false;
5959
size_t eic_max_block_size = 64UL * 1024 * 1024;
6060
EndpointStoreType endpoint_store_type = EndpointStoreType::SIEVE;
61+
bool ascend_use_fabric_mem = false;
6162
};
6263

6364
void loadGlobalConfig(GlobalConfig &config);

mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "transfer_metadata.h"
3333
#include "transfer_metadata_plugin.h"
3434
#include "transport/transport.h"
35+
#include "config.h"
3536

3637
namespace mooncake {
3738
namespace {
@@ -67,18 +68,6 @@ AscendDirectTransport::~AscendDirectTransport() {
6768
}
6869
connected_segments_.clear();
6970
}
70-
71-
// Deregister all memory
72-
std::lock_guard<std::mutex> mem_handle_lock(mem_handle_mutex_);
73-
for (const auto &[addr, mem_handle] : addr_to_mem_handle_) {
74-
auto status = adxl_->DeregisterMem(mem_handle);
75-
if (status != adxl::SUCCESS) {
76-
LOG(ERROR) << "Failed to deregister memory at address " << addr;
77-
} else {
78-
LOG(INFO) << "Deregistered memory at address " << addr;
79-
}
80-
}
81-
addr_to_mem_handle_.clear();
8271
adxl_->Finalize();
8372
}
8473

@@ -169,6 +158,12 @@ int AscendDirectTransport::InitAdxlEngine() {
169158
use_buffer_pool_ = false;
170159
}
171160
}
161+
if (globalConfig().ascend_use_fabric_mem) {
162+
options["EnableUseFabricMem"] = "1";
163+
use_buffer_pool_ = false;
164+
LOG(INFO) << "Use fabric mem is enabled.";
165+
globalConfig().ascend_use_fabric_mem = false;
166+
}
172167
auto adxl_engine_name =
173168
adxl::AscendString((host_ip + ":" + std::to_string(host_port)).c_str());
174169
auto status = adxl_->Initialize(adxl_engine_name, options);
@@ -178,7 +173,7 @@ int AscendDirectTransport::InitAdxlEngine() {
178173
}
179174
LOG(INFO) << "Success to initialize adxl engine:"
180175
<< adxl_engine_name.GetString()
181-
<< " with device_id:" << device_logic_id_;
176+
<< " with device_id:" << device_logic_id_ << ", pid:" << getpid();
182177
char *connect_timeout_str = std::getenv("ASCEND_CONNECT_TIMEOUT");
183178
if (connect_timeout_str) {
184179
std::optional<int32_t> connect_timeout =

0 commit comments

Comments
 (0)