|
27 | 27 |
|
28 | 28 | namespace UC::CacheStore { |
29 | 29 |
|
30 | | -static constexpr size_t spinThreshold = 1000; |
31 | | -static constexpr auto tryPopInterval = std::chrono::microseconds(100); |
32 | | - |
33 | 30 | DumpQueue::~DumpQueue() |
34 | 31 | { |
35 | 32 | stop_.store(true); |
@@ -76,27 +73,13 @@ void DumpQueue::DispatchStage(int32_t deviceId, size_t tensorSize, std::promise< |
76 | 73 | return; |
77 | 74 | } |
78 | 75 | started.set_value(Status::OK()); |
79 | | - size_t spinCount = 0; |
80 | | - TaskPair pair; |
81 | | - while (!stop_.load(std::memory_order_acquire)) { |
82 | | - if (waiting_.TryPop(pair)) { |
83 | | - spinCount = 0; |
84 | | - DispatchOneTask(stream.get(), tensorSize, std::move(pair.first), |
85 | | - std::move(pair.second)); |
86 | | - } else { |
87 | | - if (++spinCount < spinThreshold) { |
88 | | - std::this_thread::yield(); |
89 | | - } else { |
90 | | - std::this_thread::sleep_for(tryPopInterval); |
91 | | - spinCount = 0; |
92 | | - } |
93 | | - } |
94 | | - } |
| 76 | + waiting_.ConsumerLoop(stop_, &DumpQueue::DispatchOneTask, this, stream.get(), tensorSize); |
95 | 77 | } |
96 | 78 |
|
97 | | -void DumpQueue::DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task, |
98 | | - WaiterPtr waiter) |
| 79 | +void DumpQueue::DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPair&& pair) |
99 | 80 | { |
| 81 | + auto& task = pair.first; |
| 82 | + auto& waiter = pair.second; |
100 | 83 | if (!failureSet_->Contains(task->id)) { |
101 | 84 | auto s = DumpOneTask(stream, tensorSize, task); |
102 | 85 | if (s.Failure()) [[unlikely]] { failureSet_->Insert(task->id); } |
@@ -152,34 +135,17 @@ Status DumpQueue::DumpOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr |
152 | 135 |
|
153 | 136 | void DumpQueue::BackendDumpStage() |
154 | 137 | { |
155 | | - size_t spinCount = 0; |
156 | | - ShardTask task; |
157 | | - while (!stop_.load(std::memory_order_acquire)) { |
158 | | - if (dumping_.TryPop(task)) { |
159 | | - spinCount = 0; |
160 | | - HandleOneShardTask(task); |
161 | | - } else { |
162 | | - if (++spinCount < spinThreshold) { |
163 | | - std::this_thread::yield(); |
164 | | - } else { |
165 | | - std::this_thread::sleep_for(tryPopInterval); |
166 | | - spinCount = 0; |
| 138 | + dumping_.ConsumerLoop(stop_, [this](auto&& task) { |
| 139 | + static Detail::TaskHandle finishedBackendTaskHandle = 0; |
| 140 | + if (task.backendTaskHandle > finishedBackendTaskHandle) { |
| 141 | + auto s = backend_->Wait(task.backendTaskHandle); |
| 142 | + if (s.Failure()) { |
| 143 | + UC_ERROR("Failed({}) to wait backend task({}).", s, task.backendTaskHandle); |
| 144 | + return; |
167 | 145 | } |
| 146 | + finishedBackendTaskHandle = task.backendTaskHandle; |
168 | 147 | } |
169 | | - } |
170 | | -} |
171 | | - |
172 | | -void DumpQueue::HandleOneShardTask(ShardTask& task) |
173 | | -{ |
174 | | - static Detail::TaskHandle finishedBackendTaskHandle = 0; |
175 | | - if (task.backendTaskHandle > finishedBackendTaskHandle) { |
176 | | - auto s = backend_->Wait(task.backendTaskHandle); |
177 | | - if (s.Failure()) { |
178 | | - UC_ERROR("Failed({}) to wait backend task({}).", s, task.backendTaskHandle); |
179 | | - return; |
180 | | - } |
181 | | - finishedBackendTaskHandle = task.backendTaskHandle; |
182 | | - } |
| 148 | + }); |
183 | 149 | } |
184 | 150 |
|
185 | 151 | } // namespace UC::CacheStore |
0 commit comments