diff --git a/ucm/shared/infra/thread/thread_pool.h b/ucm/shared/infra/thread/thread_pool.h index baa514ed7..e2aecdef2 100644 --- a/ucm/shared/infra/thread/thread_pool.h +++ b/ucm/shared/infra/thread/thread_pool.h @@ -24,12 +24,17 @@ #ifndef UNIFIEDCACHE_INFRA_THREAD_POOL_H #define UNIFIEDCACHE_INFRA_THREAD_POOL_H +#include #include #include #include #include +#include #include +#include #include +#include +#include namespace UC { @@ -37,8 +42,25 @@ template class ThreadPool { using WorkerInitFn = std::function; using WorkerFn = std::function; + using WorkerTimeoutFn = std::function; using WorkerExitFn = std::function; + class StopToken { + std::shared_ptr> flag_ = std::make_shared>(false); + + public: + void RequestStop() noexcept { this->flag_->store(true, std::memory_order_relaxed); } + bool StopRequested() const noexcept { return this->flag_->load(std::memory_order_relaxed); } + }; + + struct Worker { + ssize_t tid; + std::thread th; + StopToken stop; + std::weak_ptr current; + std::atomic tp{}; + }; + public: ThreadPool() = default; ThreadPool(const ThreadPool&) = delete; @@ -46,12 +68,13 @@ class ThreadPool { ~ThreadPool() { { - std::unique_lock lk(this->mtx_); + std::lock_guard lock(this->taskMtx_); this->stop_ = true; this->cv_.notify_all(); } - for (auto& w : this->workers_) { - if (w.joinable()) { w.join(); } + if (this->monitor_.joinable()) { this->monitor_.join(); } + for (auto& worker : this->workers_) { + if (worker->th.joinable()) { worker->th.join(); } } } ThreadPool& SetWorkerFn(WorkerFn&& fn) @@ -69,6 +92,14 @@ class ThreadPool { this->exitFn_ = std::move(fn); return *this; } + ThreadPool& SetWorkerTimeoutFn(WorkerTimeoutFn&& fn, const size_t timeoutMs, + const size_t intervalMs = 1000) + { + this->timeoutFn_ = std::move(fn); + this->timeoutMs_ = timeoutMs; + this->intervalMs_ = intervalMs; + return *this; + } ThreadPool& SetNWorker(const size_t nWorker) { this->nWorker_ = nWorker; @@ -77,64 +108,126 @@ class ThreadPool { bool Run() { if (this->nWorker_ == 0) { return false; } - if (!this->fn_) { return false; } - std::list> start(this->nWorker_); - std::list> fut; - for (auto& s : start) { - fut.push_back(s.get_future()); - this->workers_.emplace_back([&] { this->Worker(s); }); + if (this->fn_ == nullptr) { return false; } + this->workers_.reserve(this->nWorker_); + for (size_t i = 0; i < this->nWorker_; i++) { + if (!this->AddOneWorker()) { return false; } } - auto success = true; - for (auto& f : fut) { - if (!f.get()) { success = false; } + if (this->timeoutMs_ > 0) { + this->monitor_ = std::thread([this] { this->MonitorLoop(); }); } - return success; + return true; } void Push(std::list& tasks) noexcept { - std::unique_lock lk(this->mtx_); + std::unique_lock lock(this->taskMtx_); this->taskQ_.splice(this->taskQ_.end(), tasks); this->cv_.notify_all(); } void Push(Task&& task) noexcept { - std::unique_lock lk(this->mtx_); + std::unique_lock lock(this->taskMtx_); this->taskQ_.push_back(std::move(task)); this->cv_.notify_one(); } private: - void Worker(std::promise& started) noexcept + bool AddOneWorker() + { + try { + auto worker = std::make_shared(); + std::promise prom; + auto fut = prom.get_future(); + worker->th = std::thread([this, worker, &prom] { this->WorkerLoop(prom, worker); }); + auto success = fut.get(); + if (!success) { return false; } + this->workers_.push_back(worker); + return true; + } catch (...) { + return false; + } + } + void WorkerLoop(std::promise& prom, std::shared_ptr worker) { + worker->tid = syscall(SYS_gettid); WorkerArgs args = nullptr; auto success = true; if (this->initFn_) { success = this->initFn_(args); } - started.set_value(success); + prom.set_value(success); while (success) { - std::unique_lock lk(this->mtx_); - this->cv_.wait(lk, [this] { return this->stop_ || !this->taskQ_.empty(); }); - if (this->stop_) { break; } - if (this->taskQ_.empty()) { continue; } - auto task = std::make_shared(std::move(this->taskQ_.front())); - this->taskQ_.pop_front(); - lk.unlock(); + std::shared_ptr task = nullptr; + { + std::unique_lock lock(this->taskMtx_); + this->cv_.wait(lock, [this, worker] { + return this->stop_ || worker->stop.StopRequested() || !this->taskQ_.empty(); + }); + if (this->stop_ || worker->stop.StopRequested()) { break; } + if (this->taskQ_.empty()) { continue; } + task = std::make_shared(std::move(this->taskQ_.front())); + this->taskQ_.pop_front(); + } + worker->current = task; + worker->tp.store(std::chrono::steady_clock::now(), std::memory_order_relaxed); this->fn_(*task, args); + if (worker->stop.StopRequested()) { break; } + worker->current.reset(); + worker->tp.store({}, std::memory_order_relaxed); } if (this->exitFn_) { this->exitFn_(args); } } + void MonitorLoop() + { + const auto interval = std::chrono::milliseconds(this->intervalMs_); + while (true) { + { + std::unique_lock lock(this->taskMtx_); + this->cv_.wait_for(lock, interval, [this] { return this->stop_; }); + if (this->stop_) { break; } + } + size_t nWorker = this->Monitor(); + for (size_t i = nWorker; i < this->nWorker_; i++) { (void)this->AddOneWorker(); } + } + } + + size_t Monitor() + { + using namespace std::chrono; + const auto timeout = milliseconds(this->timeoutMs_); + size_t nWorker = 0; + for (auto it = this->workers_.begin(); it != this->workers_.end();) { + auto tp = (*it)->tp.load(std::memory_order_relaxed); + auto task = (*it)->current.lock(); + auto now = steady_clock::now(); + if (task && tp != steady_clock::time_point{} && now - tp > timeout) { + if (this->timeoutFn_) { this->timeoutFn_(*task, (*it)->tid); } + (*it)->stop.RequestStop(); + if ((*it)->th.joinable()) { (*it)->th.detach(); } + it = this->workers_.erase(it); + } else { + it++; + } + nWorker++; + } + return nWorker; + } + private: + WorkerInitFn initFn_{nullptr}; + WorkerFn fn_{nullptr}; + WorkerTimeoutFn timeoutFn_{nullptr}; + WorkerExitFn exitFn_{nullptr}; + size_t timeoutMs_{0}; + size_t intervalMs_{0}; + size_t nWorker_{0}; bool stop_{false}; - size_t nWorker_{1}; - std::list workers_; - WorkerInitFn initFn_; - WorkerFn fn_; - WorkerExitFn exitFn_; + std::vector> workers_; + std::thread monitor_; + std::mutex taskMtx_; std::list taskQ_; - std::mutex mtx_; std::condition_variable cv_; }; -} // namespace UC +} // namespace UC #endif diff --git a/ucm/shared/test/case/infra/thread_pool_test.cc b/ucm/shared/test/case/infra/thread_pool_test.cc new file mode 100644 index 000000000..c3805c5dd --- /dev/null +++ b/ucm/shared/test/case/infra/thread_pool_test.cc @@ -0,0 +1,101 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#include "thread/thread_pool.h" +#include +#include +#include +#include +#include "thread/latch.h" + +class UCThreadPoolTest : public ::testing::Test {}; + +TEST_F(UCThreadPoolTest, TimeoutDetection) +{ + struct TestTask { + int taskId; + std::atomic* finished; + std::atomic* timeout; + }; + + constexpr size_t nWorker = 2; + constexpr size_t timeoutMs = 20; + std::atomic timeoutCount{0}; + std::atomic taskFinished{false}; + std::atomic taskTimeout{false}; + + UC::ThreadPool threadPool; + threadPool.SetNWorker(nWorker) + .SetWorkerFn([](TestTask& task, const auto&) { + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + *(task.finished) = true; + }) + .SetWorkerTimeoutFn( + [&](TestTask& task, const auto) { + timeoutCount++; + task.timeout->store(true); + }, + timeoutMs, 10) + .Run(); + std::list tasks{ + {1, &taskFinished, &taskTimeout} + }; + threadPool.Push(tasks); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + ASSERT_GT(timeoutCount.load(), 0); + ASSERT_TRUE(taskTimeout.load()); +} + +TEST_F(UCThreadPoolTest, SimulatedFileSystemHang) +{ + struct TestTask { + std::atomic* simulatingHang; + }; + + std::atomic hangDetected{0}; + constexpr size_t hangTimeoutMs = 20; + std::atomic taskHang{true}; + + UC::ThreadPool threadPool; + threadPool.SetNWorker(1) + .SetWorkerFn([](TestTask& task, const auto&) { + std::mutex fakeMutex; + std::unique_lock fakelock(fakeMutex); + std::condition_variable fakeCond; + while (*(task.simulatingHang)) { + fakeCond.wait_for(fakelock, std::chrono::milliseconds(10)); // waiting forever + } + }) + .SetWorkerTimeoutFn( + [&](TestTask& task, const auto) { + hangDetected++; + *(task.simulatingHang) = false; // stop simulating hang + }, + hangTimeoutMs, 10) + .Run(); + std::list tasks{{&taskHang}}; + threadPool.Push(tasks); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + EXPECT_GT(hangDetected.load(), 0); +} \ No newline at end of file diff --git a/ucm/store/pcstore/cc/domain/trans/trans_manager.cc b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc index d2106ab63..aeb30543b 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_manager.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc @@ -39,7 +39,7 @@ Status TransManager::Setup(const size_t rankSize, const int32_t deviceId, const if (s.Failure()) { return s; } } s = this->queue_.Setup(deviceId, streamNumber, blockSize, ioSize, ioDirect, bufferNumber, - layout, &this->failureSet_, scatterGatherEnable); + layout, &this->failureSet_, scatterGatherEnable, timeoutMs); if (s.Failure()) { return s; } this->rankSize_ = rankSize; this->timeoutMs_ = timeoutMs; diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc index 993b60299..a58a196aa 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc @@ -51,10 +51,10 @@ void TransQueue::DeviceWorker(BlockTask&& task) return; } -void TransQueue::FileWorker(BlockTask&& task) +void TransQueue::FileWorker(BlockTask& task) { if (this->failureSet_->Contains(task.owner)) { - task.done(false); + if (task.type != TransTask::Type::DUMP) { task.done(false); } return; } auto hostPtr = (uintptr_t)task.buffer.get(); @@ -75,10 +75,22 @@ void TransQueue::FileWorker(BlockTask&& task) task.done(false); } +void TransQueue::FileWorkerTimeout(BlockTask& task) +{ + static size_t lastTaskId = 0; + if (lastTaskId != task.owner) { + lastTaskId = task.owner; + UC_WARN("Task({}) timeout.", task.owner); + } + + if (task.type != TransTask::Type::DUMP) { this->failureSet_->Insert(task.owner); } + if (task.done) { task.done(false); } +} + Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize, const size_t ioSize, const bool ioDirect, const size_t bufferNumber, const SpaceLayout* layout, TaskSet* failureSet_, - const bool scatterGatherEnable) + const bool scatterGatherEnable, const size_t timeoutMs) { Trans::Device device; auto ts = device.Setup(deviceId); @@ -112,11 +124,14 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons return ts.Success(); }) .SetWorkerFn([this](auto t, auto) { this->DeviceWorker(std::move(t)); }) + .SetNWorker(streamNumber) .Run(); if (!success) { return Status::Error(); } - success = this->filePool_.SetWorkerFn([this](auto t, auto) { this->FileWorker(std::move(t)); }) - .SetNWorker(streamNumber) - .Run(); + success = + this->filePool_.SetWorkerFn([this](auto t, auto) { this->FileWorker(t); }) + .SetWorkerTimeoutFn([this](auto t, auto) { this->FileWorkerTimeout(t); }, timeoutMs) + .SetNWorker(streamNumber) + .Run(); if (!success) { return Status::Error(); } this->layout_ = layout; this->ioSize_ = ioSize; @@ -226,4 +241,4 @@ void TransQueue::DispatchSatterGatherDump(TaskPtr task, WaiterPtr waiter) } } -} // namespace UC +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.h b/ucm/store/pcstore/cc/domain/trans/trans_queue.h index 377f09d5d..abf6b1f78 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_queue.h +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.h @@ -46,12 +46,14 @@ class TransQueue { std::function done; }; void DeviceWorker(BlockTask&& task); - void FileWorker(BlockTask&& task); + void FileWorker(BlockTask& task); + void FileWorkerTimeout(BlockTask& task); public: Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize, const size_t ioSize, const bool ioDirect, const size_t bufferNumber, - const SpaceLayout* layout, TaskSet* failureSet_, const bool scatterGatherEnable); + const SpaceLayout* layout, TaskSet* failureSet_, const bool scatterGatherEnable, + const size_t timeoutMs); void Dispatch(TaskPtr task, WaiterPtr waiter); void DispatchDump(TaskPtr task, WaiterPtr waiter); void DispatchSatterGatherDump(TaskPtr task, WaiterPtr waiter); @@ -70,6 +72,6 @@ class TransQueue { bool scatterGatherEnable_; }; -} // namespace UC +} // namespace UC #endif