@@ -77,7 +77,8 @@ void TransQueue::FileWorker(BlockTask&& task)
7777
7878Status TransQueue::Setup (const int32_t deviceId, const size_t streamNumber, const size_t blockSize,
7979 const size_t ioSize, const bool ioDirect, const size_t bufferNumber,
80- const SpaceLayout* layout, TaskSet* failureSet_)
80+ const SpaceLayout* layout, TaskSet* failureSet_,
81+ const bool scatterGatherEnable)
8182{
8283 Trans::Device device;
8384 auto ts = device.Setup (deviceId);
@@ -91,6 +92,14 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons
9192 UC_ERROR (" Failed to make buffer and stream on device({})." , deviceId);
9293 return Status::Error ();
9394 }
95+ if (scatterGatherEnable) {
96+ devBuffer_ = device.MakeBuffer ();
97+ smStream_ = device.MakeSMStream ();
98+ if (!devBuffer_ || !smStream_) {
99+ UC_ERROR (" Failed to make devBuffer and smStream on device({})." , deviceId);
100+ return Status::Error ();
101+ }
102+ }
94103 ts = buffer_->MakeHostBuffers (blockSize, bufferNumber);
95104 if (ts.Failure ()) {
96105 UC_ERROR (" Failed({}) to make host buffer({},{})." , ts.ToString (), blockSize, bufferNumber);
@@ -108,13 +117,18 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons
108117 this ->ioSize_ = ioSize;
109118 this ->ioDirect_ = ioDirect;
110119 this ->failureSet_ = failureSet_;
120+ this ->scatterGatherEnable_ = scatterGatherEnable;
111121 return Status::OK ();
112122}
113123
114124void TransQueue::Dispatch (TaskPtr task, WaiterPtr waiter)
115125{
116126 if (task->type == TransTask::Type::DUMP) {
117- this ->DispatchDump (task, waiter);
127+ if (this ->scatterGatherEnable_ ) {
128+ this ->DispatchSatterGatherDump (task, waiter);
129+ } else {
130+ this ->DispatchDump (task, waiter);
131+ }
118132 return ;
119133 }
120134 task->ForEachGroup (
@@ -171,4 +185,40 @@ void TransQueue::DispatchDump(TaskPtr task, WaiterPtr waiter)
171185 }
172186}
173187
188+ void TransQueue::DispatchSatterGatherDump (TaskPtr task, WaiterPtr waiter)
189+ {
190+ std::vector<BlockTask> blocks;
191+ blocks.reserve (task->GroupNumber ());
192+ std::vector<std::shared_ptr<void >> addrs;
193+ addrs.reserve (task->GroupNumber ());
194+ task->ForEachGroup (
195+ [task, &blocks, &addrs, this ](const std::string& block, std::vector<uintptr_t >& shards) {
196+ BlockTask blockTask;
197+ blockTask.owner = task->id ;
198+ blockTask.block = block;
199+ blockTask.type = task->type ;
200+ auto number = shards.size ();
201+ auto bufferSize = this ->ioSize_ * number;
202+ blockTask.buffer = buffer_->GetHostBuffer (bufferSize);
203+ std::swap (blockTask.shards , shards);
204+ auto device = (void *)blockTask.shards .data ();
205+ auto host = blockTask.buffer .get ();
206+ auto devAddr = this ->devBuffer_ ->MakeDeviceBuffer (sizeof (void *) * number);
207+ smStream_->HostToDeviceAsync (device, devAddr.get (), sizeof (void *) * number);
208+ smStream_->DeviceToHostAsync ((void **)devAddr.get (), host, this ->ioSize_ , number);
209+ addrs.push_back (devAddr);
210+ blocks.push_back (std::move (blockTask));
211+ });
212+ auto s = smStream_->Synchronized ();
213+ if (s.Failure ()) { this ->failureSet_ ->Insert (task->id ); }
214+ for (auto && block : blocks) {
215+ if (s.Failure ()) {
216+ waiter->Done (nullptr );
217+ return ;
218+ }
219+ this ->filePool_ .Push (std::move (block));
220+ waiter->Done ([task, ioSize = this ->ioSize_ ] { UC_DEBUG (" {}" , task->Epilog (ioSize)); });
221+ }
222+ }
223+
174224} // namespace UC
0 commit comments