Skip to content

Commit 3476145

Browse files
authored
[feat] Add configurable scatter-gather option for pcstore (#483)
Add configurable scatter-gather
1 parent a3658e6 commit 3476145

File tree

8 files changed

+67
-7
lines changed

8 files changed

+67
-7
lines changed

ucm/store/pcstore/cc/api/pcstore.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class PcStoreImpl : public PcStore {
4141
config.transferLocalRankSize, config.transferDeviceId, config.transferStreamNumber,
4242
config.kvcacheBlockSize, config.transferIoSize, config.transferIoDirect,
4343
config.transferBufferNumber, this->spaceMgr_.GetSpaceLayout(),
44-
config.transferTimeoutMs);
44+
config.transferTimeoutMs, config.transferScatterGatherEnable);
4545
if (status.Failure()) { return status.Underlying(); }
4646
}
4747
this->ShowConfig(config);
@@ -93,6 +93,7 @@ class PcStoreImpl : public PcStore {
9393
UC_INFO("Set UC::StreamNumber to {}.", config.transferStreamNumber);
9494
UC_INFO("Set UC::BufferNumber to {}.", config.transferBufferNumber);
9595
UC_INFO("Set UC::TimeoutMs to {}.", config.transferTimeoutMs);
96+
UC_INFO("Set UC::ScatterGatherEnable to {}.", config.transferScatterGatherEnable);
9697
}
9798

9899
private:

ucm/store/pcstore/cc/api/pcstore.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class PcStore : CCStore<TransTask> {
4242
size_t transferStreamNumber{8};
4343
size_t transferBufferNumber{4096};
4444
size_t transferTimeoutMs{30000};
45+
bool transferScatterGatherEnable{false};
4546

4647
Config(const std::vector<std::string>& storageBackends, const size_t kvcacheBlockSize,
4748
const bool transferEnable)

ucm/store/pcstore/cc/domain/trans/trans_manager.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace UC {
2929
Status TransManager::Setup(const size_t rankSize, const int32_t deviceId, const size_t streamNumber,
3030
const size_t blockSize, const size_t ioSize, const bool ioDirect,
3131
const size_t bufferNumber, const SpaceLayout* layout,
32-
const size_t timeoutMs)
32+
const size_t timeoutMs, const bool scatterGatherEnable)
3333
{
3434
auto s = Status::OK();
3535
if (rankSize > 1) {
@@ -38,7 +38,7 @@ Status TransManager::Setup(const size_t rankSize, const int32_t deviceId, const
3838
if (s.Failure()) { return s; }
3939
}
4040
s = this->queue_.Setup(deviceId, streamNumber, blockSize, ioSize, ioDirect, bufferNumber,
41-
layout, &this->failureSet_);
41+
layout, &this->failureSet_, scatterGatherEnable);
4242
if (s.Failure()) { return s; }
4343
this->rankSize_ = rankSize;
4444
this->timeoutMs_ = timeoutMs;

ucm/store/pcstore/cc/domain/trans/trans_manager.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class TransManager {
3333
public:
3434
Status Setup(const size_t rankSize, const int32_t deviceId, const size_t streamNumber,
3535
const size_t blockSize, const size_t ioSize, const bool ioDirect,
36-
const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs);
36+
const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs,
37+
const bool scatterGatherEnable);
3738
Status Submit(TransTask task, size_t& taskId) noexcept;
3839
Status Wait(const size_t taskId) noexcept;
3940
Status Check(const size_t taskId, bool& finish) noexcept;

ucm/store/pcstore/cc/domain/trans/trans_queue.cc

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ void TransQueue::FileWorker(BlockTask&& task)
7777

7878
Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize,
7979
const size_t ioSize, const bool ioDirect, const size_t bufferNumber,
80-
const SpaceLayout* layout, TaskSet* failureSet_)
80+
const SpaceLayout* layout, TaskSet* failureSet_,
81+
const bool scatterGatherEnable)
8182
{
8283
Trans::Device device;
8384
auto ts = device.Setup(deviceId);
@@ -91,6 +92,14 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons
9192
UC_ERROR("Failed to make buffer and stream on device({}).", deviceId);
9293
return Status::Error();
9394
}
95+
if (scatterGatherEnable) {
96+
devBuffer_ = device.MakeBuffer();
97+
smStream_ = device.MakeSMStream();
98+
if (!devBuffer_ || !smStream_) {
99+
UC_ERROR("Failed to make devBuffer and smStream on device({}).", deviceId);
100+
return Status::Error();
101+
}
102+
}
94103
ts = buffer_->MakeHostBuffers(blockSize, bufferNumber);
95104
if (ts.Failure()) {
96105
UC_ERROR("Failed({}) to make host buffer({},{}).", ts.ToString(), blockSize, bufferNumber);
@@ -108,13 +117,18 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons
108117
this->ioSize_ = ioSize;
109118
this->ioDirect_ = ioDirect;
110119
this->failureSet_ = failureSet_;
120+
this->scatterGatherEnable_ = scatterGatherEnable;
111121
return Status::OK();
112122
}
113123

114124
void TransQueue::Dispatch(TaskPtr task, WaiterPtr waiter)
115125
{
116126
if (task->type == TransTask::Type::DUMP) {
117-
this->DispatchDump(task, waiter);
127+
if (this->scatterGatherEnable_) {
128+
this->DispatchSatterGatherDump(task, waiter);
129+
} else {
130+
this->DispatchDump(task, waiter);
131+
}
118132
return;
119133
}
120134
task->ForEachGroup(
@@ -171,4 +185,40 @@ void TransQueue::DispatchDump(TaskPtr task, WaiterPtr waiter)
171185
}
172186
}
173187

188+
// Scatter-gather variant of the DUMP dispatch path.
//
// For every group reported by task->ForEachGroup, a BlockTask is built, a host buffer of
// ioSize_ * <shard count> bytes is taken from buffer_, the shard address array is uploaded
// into a freshly made device buffer, and a batched device-to-host copy is queued on smStream_.
// After smStream_ is synchronized, the collected blocks are pushed to filePool_ and the waiter
// is notified once per block. If synchronization fails, the task id is inserted into
// failureSet_ and no block is pushed.
//
// task   - the transfer task being dispatched (its id, type and groups are read here).
// waiter - notified through Done(): once per pushed block on success, once with nullptr on
//          failure.
void TransQueue::DispatchSatterGatherDump(TaskPtr task, WaiterPtr waiter)
189+
{
190+
// One BlockTask per group; filled by the lambda below and consumed after synchronization.
std::vector<BlockTask> blocks;
191+
blocks.reserve(task->GroupNumber());
192+
// Holds the per-group device address buffers until this function returns, i.e. until after
// smStream_->Synchronized() — presumably so the queued async copies never outlive them.
std::vector<std::shared_ptr<void>> addrs;
193+
addrs.reserve(task->GroupNumber());
194+
task->ForEachGroup(
195+
[task, &blocks, &addrs, this](const std::string& block, std::vector<uintptr_t>& shards) {
196+
BlockTask blockTask;
197+
blockTask.owner = task->id;
198+
blockTask.block = block;
199+
blockTask.type = task->type;
200+
auto number = shards.size();
201+
// The host staging buffer is sized for `number` pieces of ioSize_ bytes each.
auto bufferSize = this->ioSize_ * number;
202+
blockTask.buffer = buffer_->GetHostBuffer(bufferSize);
203+
// Takes over the group's shard list; `shards` holds blockTask's previous (empty) list afterwards.
std::swap(blockTask.shards, shards);
204+
// NOTE(review): despite the name, `device` points at the shard address array stored in
// blockTask.shards (a std::vector<uintptr_t>); presumably each entry is a device address — confirm.
auto device = (void*)blockTask.shards.data();
205+
auto host = blockTask.buffer.get();
206+
// NOTE(review): sizes below use sizeof(void*) per entry while the array element type is
// uintptr_t; these match on common platforms but are not tied together in the code.
auto devAddr = this->devBuffer_->MakeDeviceBuffer(sizeof(void*) * number);
207+
// Queue the upload of the address array into the device buffer.
smStream_->HostToDeviceAsync(device, devAddr.get(), sizeof(void*) * number);
208+
// Queue the batched copy into the host buffer; presumably it reads `number` pieces of
// ioSize_ bytes from the addresses stored in devAddr — TODO confirm against Trans::Stream.
smStream_->DeviceToHostAsync((void**)devAddr.get(), host, this->ioSize_, number);
209+
addrs.push_back(devAddr);
210+
blocks.push_back(std::move(blockTask));
211+
});
212+
// Single synchronization point for all copies queued above.
auto s = smStream_->Synchronized();
213+
if (s.Failure()) { this->failureSet_->Insert(task->id); }
214+
for (auto&& block : blocks) {
215+
// On failure only the first iteration runs: Done(nullptr) is called once and the function
// returns without pushing any block.
// NOTE(review): when `blocks` is empty the loop body never executes, so waiter->Done() is
// not called at all, even on failure — verify this matches what the waiter expects.
if (s.Failure()) {
216+
waiter->Done(nullptr);
217+
return;
218+
}
219+
this->filePool_.Push(std::move(block));
220+
// The callback passed to Done() only emits a debug log built from task->Epilog(ioSize).
waiter->Done([task, ioSize = this->ioSize_] { UC_DEBUG("{}", task->Epilog(ioSize)); });
221+
}
222+
}
223+
174224
} // namespace UC

ucm/store/pcstore/cc/domain/trans/trans_queue.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,23 @@ class TransQueue {
5151
public:
5252
Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize,
5353
const size_t ioSize, const bool ioDirect, const size_t bufferNumber,
54-
const SpaceLayout* layout, TaskSet* failureSet_);
54+
const SpaceLayout* layout, TaskSet* failureSet_, const bool scatterGatherEnable);
5555
void Dispatch(TaskPtr task, WaiterPtr waiter);
5656
void DispatchDump(TaskPtr task, WaiterPtr waiter);
57+
void DispatchSatterGatherDump(TaskPtr task, WaiterPtr waiter);
5758

5859
private:
5960
std::unique_ptr<Trans::Buffer> buffer_{nullptr};
6061
std::unique_ptr<Trans::Stream> stream_{nullptr};
62+
std::unique_ptr<Trans::Buffer> devBuffer_{nullptr};
63+
std::unique_ptr<Trans::Stream> smStream_{nullptr};
6164
const SpaceLayout* layout_;
6265
size_t ioSize_;
6366
bool ioDirect_;
6467
ThreadPool<BlockTask> devPool_;
6568
ThreadPool<BlockTask> filePool_;
6669
TaskSet* failureSet_;
70+
bool scatterGatherEnable_;
6771
};
6872

6973
} // namespace UC

ucm/store/pcstore/cpy/pcstore.py.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ PYBIND11_MODULE(ucmpcstore, module)
101101
config.def_readwrite("transferIoSize", &UC::PcStorePy::Config::transferIoSize);
102102
config.def_readwrite("transferBufferNumber", &UC::PcStorePy::Config::transferBufferNumber);
103103
config.def_readwrite("transferTimeoutMs", &UC::PcStorePy::Config::transferTimeoutMs);
104+
config.def_readwrite("transferScatterGatherEnable",
105+
&UC::PcStorePy::Config::transferScatterGatherEnable);
104106
store.def(py::init<>());
105107
store.def("CCStoreImpl", &UC::PcStorePy::CCStoreImpl);
106108
store.def("Setup", &UC::PcStorePy::Setup);

ucm/store/pcstore/pcstore_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, config: Dict):
5353
param.transferStreamNumber = config.get("stream_number", 8)
5454
param.transferBufferNumber = config.get("buffer_number", 4096)
5555
param.transferLocalRankSize = config.get("local_rank_size", 8)
56+
param.transferScatterGatherEnable = config.get("use_scatter_gatter", False)
5657
ret = self.store.Setup(param)
5758
if ret != 0:
5859
msg = f"Failed to initialize ucmpcstore, errcode: {ret}."

0 commit comments

Comments
 (0)