
Commit 0511e1e

consumer loop for queue

1 parent 6347050 · commit 0511e1e

5 files changed: +50, -90 lines

ucm/shared/infra/template/spsc_ring_queue.h (28 additions, 0 deletions)

@@ -27,6 +27,7 @@
 #include <array>
 #include <atomic>
 #include <cstddef>
+#include <functional>
 #include <thread>
 
 namespace UC {
@@ -74,6 +75,33 @@ class SpscRingQueue {
         tail_.store((currentTail + 1) & (N - 1), std::memory_order_release);
         return true;
     }
+
+    template <typename ConsumerHandler, typename... Args>
+    void ConsumerLoop(const std::atomic_bool& stop, ConsumerHandler&& handler, Args&&... args)
+    {
+        constexpr size_t kSpinLimit = 16;
+        constexpr size_t kTaskBatch = 64;
+        size_t spinCount = 0;
+        size_t taskCount = 0;
+        T task;
+        while (!stop.load(std::memory_order_relaxed)) {
+            if (TryPop(task)) {
+                spinCount = 0;
+                std::invoke(handler, std::forward<Args>(args)..., std::move(task));
+                if (++taskCount % kTaskBatch == 0) {
+                    if (stop.load(std::memory_order_acquire)) { break; }
+                }
+                continue;
+            }
+            if (++spinCount < kSpinLimit) {
+                std::this_thread::yield();
+            } else {
+                if (stop.load(std::memory_order_acquire)) { break; }
+                std::this_thread::sleep_for(std::chrono::microseconds(100));
+                spinCount = 0;
+            }
+        }
+    }
 };
 
 } // namespace UC
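
For context (not part of this commit), here is a minimal usage sketch of the new ConsumerLoop helper. Only SpscRingQueue, TryPop and ConsumerLoop come from the diff above; the Worker class, the Run/Submit/Handle names, and the push-method name TryPush are illustrative assumptions. The key point is the argument order: bound arguments are forwarded first on every invocation, and the popped element is appended as the final handler argument, which is why the member-function handlers in dump_queue.cc and load_queue.cc below take the task as their last parameter.

// Minimal sketch, assuming a push method named TryPush (its real name is not
// visible in this diff) and the header path used elsewhere in the repo.
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

#include "ucm/shared/infra/template/spsc_ring_queue.h"

class Worker {
public:
    // Runs on the consumer thread until `stop` becomes true.
    void Run(const std::atomic_bool& stop)
    {
        // Bound arguments (`this`, the prefix) are forwarded first; ConsumerLoop
        // appends the popped element as the final handler argument.
        queue_.ConsumerLoop(stop, &Worker::Handle, this, "got");
    }

    // Producer side; returns false when the ring is full (TryPush is an assumed name).
    bool Submit(int value) { return queue_.TryPush(value); }

private:
    void Handle(const char* prefix, int&& value) { std::printf("%s %d\n", prefix, value); }

    UC::SpscRingQueue<int, 256> queue_;
};

int main()
{
    std::atomic_bool stop{false};
    Worker worker;
    std::thread consumer([&] { worker.Run(stop); });
    for (int i = 0; i < 8; ++i) {
        while (!worker.Submit(i)) { std::this_thread::yield(); }
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10)); // let the consumer drain
    stop.store(true);
    consumer.join();
    return 0;
}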

ucm/store/cache/cc/dump_queue.cc (13 additions, 47 deletions)

@@ -27,9 +27,6 @@
 
 namespace UC::CacheStore {
 
-static constexpr size_t spinThreshold = 1000;
-static constexpr auto tryPopInterval = std::chrono::microseconds(100);
-
 DumpQueue::~DumpQueue()
 {
     stop_.store(true);
@@ -76,27 +73,13 @@ void DumpQueue::DispatchStage(int32_t deviceId, size_t tensorSize, std::promise<
         return;
     }
     started.set_value(Status::OK());
-    size_t spinCount = 0;
-    TaskPair pair;
-    while (!stop_.load(std::memory_order_acquire)) {
-        if (waiting_.TryPop(pair)) {
-            spinCount = 0;
-            DispatchOneTask(stream.get(), tensorSize, std::move(pair.first),
-                            std::move(pair.second));
-        } else {
-            if (++spinCount < spinThreshold) {
-                std::this_thread::yield();
-            } else {
-                std::this_thread::sleep_for(tryPopInterval);
-                spinCount = 0;
-            }
-        }
-    }
+    waiting_.ConsumerLoop(stop_, &DumpQueue::DispatchOneTask, this, stream.get(), tensorSize);
 }
 
-void DumpQueue::DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task,
-                                WaiterPtr waiter)
+void DumpQueue::DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPair&& pair)
 {
+    auto& task = pair.first;
+    auto& waiter = pair.second;
     if (!failureSet_->Contains(task->id)) {
         auto s = DumpOneTask(stream, tensorSize, task);
         if (s.Failure()) [[unlikely]] { failureSet_->Insert(task->id); }
@@ -152,34 +135,17 @@ Status DumpQueue::DumpOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr
 
 void DumpQueue::BackendDumpStage()
 {
-    size_t spinCount = 0;
-    ShardTask task;
-    while (!stop_.load(std::memory_order_acquire)) {
-        if (dumping_.TryPop(task)) {
-            spinCount = 0;
-            HandleOneShardTask(task);
-        } else {
-            if (++spinCount < spinThreshold) {
-                std::this_thread::yield();
-            } else {
-                std::this_thread::sleep_for(tryPopInterval);
-                spinCount = 0;
+    dumping_.ConsumerLoop(stop_, [this](auto&& task) {
+        static Detail::TaskHandle finishedBackendTaskHandle = 0;
+        if (task.backendTaskHandle > finishedBackendTaskHandle) {
+            auto s = backend_->Wait(task.backendTaskHandle);
+            if (s.Failure()) {
+                UC_ERROR("Failed({}) to wait backend task({}).", s, task.backendTaskHandle);
+                return;
             }
+            finishedBackendTaskHandle = task.backendTaskHandle;
         }
-    }
-}
-
-void DumpQueue::HandleOneShardTask(ShardTask& task)
-{
-    static Detail::TaskHandle finishedBackendTaskHandle = 0;
-    if (task.backendTaskHandle > finishedBackendTaskHandle) {
-        auto s = backend_->Wait(task.backendTaskHandle);
-        if (s.Failure()) {
-            UC_ERROR("Failed({}) to wait backend task({}).", s, task.backendTaskHandle);
-            return;
-        }
-        finishedBackendTaskHandle = task.backendTaskHandle;
-    }
+    });
 }
 
 } // namespace UC::CacheStore

ucm/store/cache/cc/dump_queue.h (1 addition, 2 deletions)

@@ -63,10 +63,9 @@ class DumpQueue {
 
 private:
     void DispatchStage(int32_t deviceId, size_t tensorSize, std::promise<Status>& started);
-    void DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task, WaiterPtr waiter);
+    void DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPair&& pair);
     Status DumpOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task);
     void BackendDumpStage();
-    void HandleOneShardTask(ShardTask& task);
 };
 
 } // namespace UC::CacheStore

ucm/store/cache/cc/load_queue.cc (6 additions, 39 deletions)

@@ -57,29 +57,12 @@ void LoadQueue::Submit(TaskPtr task, WaiterPtr waiter)
     waiter->Done();
 }
 
-void LoadQueue::DispatchStage()
-{
-    constexpr size_t spinThreshold = 1000;
-    constexpr auto tryPopInterval = std::chrono::microseconds(100);
-    size_t spinCount = 0;
-    TaskPair pair;
-    while (!stop_.load(std::memory_order_acquire)) {
-        if (waiting_.TryPop(pair)) {
-            spinCount = 0;
-            DispatchOneTask(std::move(pair.first), std::move(pair.second));
-        } else {
-            if (++spinCount < spinThreshold) {
-                std::this_thread::yield();
-            } else {
-                std::this_thread::sleep_for(tryPopInterval);
-                spinCount = 0;
-            }
-        }
-    }
-}
+void LoadQueue::DispatchStage() { waiting_.ConsumerLoop(stop_, &LoadQueue::DispatchOneTask, this); }
 
-void LoadQueue::DispatchOneTask(TaskPtr task, WaiterPtr waiter)
+void LoadQueue::DispatchOneTask(TaskPair&& pair)
 {
+    auto& task = pair.first;
+    auto& waiter = pair.second;
     if (failureSet_->Contains(task->id)) {
         waiter->Done();
         return;
@@ -133,26 +116,10 @@ void LoadQueue::TransferStage(int32_t deviceId, size_t tensorSize, std::promise<
         return;
     }
     started.set_value(Status::OK());
-    constexpr size_t spinThreshold = 1000;
-    constexpr auto tryPopInterval = std::chrono::microseconds(100);
-    size_t spinCount = 0;
-    ShardTask task;
-    while (!stop_.load(std::memory_order_acquire)) {
-        if (running_.TryPop(task)) {
-            spinCount = 0;
-            TransferOneTask(stream.get(), tensorSize, task);
-        } else {
-            if (++spinCount < spinThreshold) {
-                std::this_thread::yield();
-            } else {
-                std::this_thread::sleep_for(tryPopInterval);
-                spinCount = 0;
-            }
-        }
-    }
+    running_.ConsumerLoop(stop_, &LoadQueue::TransferOneTask, this, stream.get(), tensorSize);
 }
 
-void LoadQueue::TransferOneTask(Trans::Stream* stream, size_t tensorSize, ShardTask& task)
+void LoadQueue::TransferOneTask(Trans::Stream* stream, size_t tensorSize, ShardTask&& task)
 {
     static Detail::TaskHandle finishedBackendTaskHandle = 0;
     if (task.backendTaskHandle > finishedBackendTaskHandle) {

ucm/store/cache/cc/load_queue.h (2 additions, 2 deletions)

@@ -66,9 +66,9 @@ class LoadQueue {
 
 private:
     void DispatchStage();
-    void DispatchOneTask(TaskPtr task, WaiterPtr waiter);
+    void DispatchOneTask(TaskPair&& pair);
     void TransferStage(int32_t deviceId, size_t tensorSize, std::promise<Status>& started);
-    void TransferOneTask(Trans::Stream* stream, size_t tensorSize, ShardTask& task);
+    void TransferOneTask(Trans::Stream* stream, size_t tensorSize, ShardTask&& task);
 };
 
 } // namespace UC::CacheStore
