|
| 1 | +/** |
| 2 | + * MIT License |
| 3 | + * |
| 4 | + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. |
| 5 | + * |
| 6 | + * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | + * of this software and associated documentation files (the "Software"), to deal |
| 8 | + * in the Software without restriction, including without limitation the rights |
| 9 | + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | + * copies of the Software, and to permit persons to whom the Software is |
| 11 | + * furnished to do so, subject to the following conditions: |
| 12 | + * |
| 13 | + * The above copyright notice and this permission notice shall be included in all |
| 14 | + * copies or substantial portions of the Software. |
| 15 | + * |
| 16 | + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | + * SOFTWARE. |
| 23 | + * */ |
| 24 | +#include "dump_queue.h" |
| 25 | +#include "logger/logger.h" |
| 26 | +#include "trans/device.h" |
| 27 | + |
| 28 | +namespace UC::CacheStore { |
| 29 | + |
| 30 | +static constexpr size_t spinThreshold = 1000; |
| 31 | +static constexpr auto tryPopInterval = std::chrono::microseconds(100); |
| 32 | + |
| 33 | +DumpQueue::~DumpQueue() |
| 34 | +{ |
| 35 | + stop_.store(true); |
| 36 | + if (dispatcher_.joinable()) { dispatcher_.join(); } |
| 37 | + if (dumper_.joinable()) { dumper_.join(); } |
| 38 | +} |
| 39 | + |
| 40 | +Status DumpQueue::Setup(const Config& config, TaskIdSet* failureSet, TransBuffer* buffer) |
| 41 | +{ |
| 42 | + failureSet_ = failureSet; |
| 43 | + buffer_ = buffer; |
| 44 | + backend_ = static_cast<Store*>((void*)config.backend); |
| 45 | + dumper_ = std::thread{&DumpQueue::BackendDumpStage, this}; |
| 46 | + std::promise<Status> started; |
| 47 | + auto fut = started.get_future(); |
| 48 | + dispatcher_ = std::thread{&DumpQueue::DispatchStage, this, config.deviceId, config.tensorSize, |
| 49 | + std::ref(started)}; |
| 50 | + return fut.get(); |
| 51 | +} |
| 52 | + |
| 53 | +void DumpQueue::Submit(TaskPtr task, WaiterPtr waiter) |
| 54 | +{ |
| 55 | + waiter->Up(); |
| 56 | + auto success = waiting_.TryPush({task, waiter}); |
| 57 | + if (success) { return; } |
| 58 | + UC_ERROR("Waiting queue full, submit dump task({}) failed.", task->id); |
| 59 | + failureSet_->Insert(task->id); |
| 60 | + waiter->Done(); |
| 61 | +} |
| 62 | + |
| 63 | +void DumpQueue::DispatchStage(int32_t deviceId, size_t tensorSize, std::promise<Status>& started) |
| 64 | +{ |
| 65 | + Trans::Device device; |
| 66 | + auto s = device.Setup(deviceId); |
| 67 | + if (s.Failure()) [[unlikely]] { |
| 68 | + UC_ERROR("Failed({}) to setup device({}).", s, deviceId); |
| 69 | + started.set_value(s); |
| 70 | + return; |
| 71 | + } |
| 72 | + auto stream = device.MakeStream(); |
| 73 | + if (!stream) [[unlikely]] { |
| 74 | + UC_ERROR("Failed to make stream on device({}).", deviceId); |
| 75 | + started.set_value(Status::Error()); |
| 76 | + return; |
| 77 | + } |
| 78 | + started.set_value(Status::OK()); |
| 79 | + size_t spinCount = 0; |
| 80 | + TaskPair pair; |
| 81 | + while (!stop_.load(std::memory_order_acquire)) { |
| 82 | + if (waiting_.TryPop(pair)) { |
| 83 | + spinCount = 0; |
| 84 | + DispatchOneTask(stream.get(), tensorSize, std::move(pair.first), |
| 85 | + std::move(pair.second)); |
| 86 | + } else { |
| 87 | + if (++spinCount < spinThreshold) { |
| 88 | + std::this_thread::yield(); |
| 89 | + } else { |
| 90 | + std::this_thread::sleep_for(tryPopInterval); |
| 91 | + spinCount = 0; |
| 92 | + } |
| 93 | + } |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +void DumpQueue::DispatchOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task, |
| 98 | + WaiterPtr waiter) |
| 99 | +{ |
| 100 | + if (!failureSet_->Contains(task->id)) { |
| 101 | + auto s = DumpOneTask(stream, tensorSize, task); |
| 102 | + if (s.Failure()) [[unlikely]] { failureSet_->Insert(task->id); } |
| 103 | + } |
| 104 | + waiter->Done(); |
| 105 | +} |
| 106 | + |
| 107 | +Status DumpQueue::DumpOneTask(Trans::Stream* stream, size_t tensorSize, TaskPtr task) |
| 108 | +{ |
| 109 | + Detail::TaskDesc backendTaskDesc; |
| 110 | + backendTaskDesc.brief = "Cache2Backend"; |
| 111 | + const auto nShard = task->desc.size(); |
| 112 | + std::vector<size_t> backendTaskIndex; |
| 113 | + backendTaskIndex.reserve(nShard); |
| 114 | + std::vector<ShardTask> shardTasks(nShard); |
| 115 | + for (size_t i = 0; i < task->desc.size(); i++) { |
| 116 | + auto& shard = task->desc[i]; |
| 117 | + auto& shardTask = shardTasks[i]; |
| 118 | + shardTask.bufferHandle = buffer_->Get(shard.owner, shard.index); |
| 119 | + if (!shardTask.bufferHandle.Owner()) { continue; } |
| 120 | + if (!shardTask.bufferHandle.Ready()) { |
| 121 | + auto s = stream->DeviceToHostAsync(shard.addrs.data(), shardTask.bufferHandle.Data(), |
| 122 | + tensorSize, shard.addrs.size()); |
| 123 | + if (s.Failure()) [[unlikely]] { |
| 124 | + UC_ERROR("Failed({}) to do D2H({}) batch({}) async.", s, tensorSize, |
| 125 | + shard.addrs.size()); |
| 126 | + return s; |
| 127 | + } |
| 128 | + } |
| 129 | + backendTaskDesc.push_back( |
| 130 | + Detail::Shard{shard.owner, shard.index, {shardTask.bufferHandle.Data()}}); |
| 131 | + backendTaskIndex.emplace_back(i); |
| 132 | + } |
| 133 | + if (backendTaskIndex.empty()) { return Status::OK(); } |
| 134 | + auto s = stream->Synchronized(); |
| 135 | + if (s.Failure()) [[unlikely]] { |
| 136 | + UC_ERROR("Failed({}) to sync on stream.", s); |
| 137 | + return s; |
| 138 | + } |
| 139 | + for (const auto& i : backendTaskIndex) { shardTasks[i].bufferHandle.MarkReady(); } |
| 140 | + auto res = backend_->Dump(std::move(backendTaskDesc)); |
| 141 | + if (!res) [[unlikely]] { |
| 142 | + UC_ERROR("Failed({}) to submit dump task to backend.", res.Error()); |
| 143 | + return res.Error(); |
| 144 | + } |
| 145 | + for (const auto& i : backendTaskIndex) { |
| 146 | + auto& shardTask = shardTasks[i]; |
| 147 | + shardTask.backendTaskHandle = res.Value(); |
| 148 | + dumping_.Push(std::move(shardTask)); |
| 149 | + } |
| 150 | + return Status::OK(); |
| 151 | +} |
| 152 | + |
| 153 | +void DumpQueue::BackendDumpStage() |
| 154 | +{ |
| 155 | + size_t spinCount = 0; |
| 156 | + ShardTask task; |
| 157 | + while (!stop_.load(std::memory_order_acquire)) { |
| 158 | + if (dumping_.TryPop(task)) { |
| 159 | + spinCount = 0; |
| 160 | + HandleOneShardTask(task); |
| 161 | + } else { |
| 162 | + if (++spinCount < spinThreshold) { |
| 163 | + std::this_thread::yield(); |
| 164 | + } else { |
| 165 | + std::this_thread::sleep_for(tryPopInterval); |
| 166 | + spinCount = 0; |
| 167 | + } |
| 168 | + } |
| 169 | + } |
| 170 | +} |
| 171 | + |
| 172 | +void DumpQueue::HandleOneShardTask(ShardTask& task) |
| 173 | +{ |
| 174 | + static Detail::TaskHandle finishedBackendTaskHandle = 0; |
| 175 | + if (task.backendTaskHandle > finishedBackendTaskHandle) { |
| 176 | + auto s = backend_->Wait(task.backendTaskHandle); |
| 177 | + if (s.Failure()) { |
| 178 | + UC_ERROR("Failed({}) to wait backend task({}).", s, task.backendTaskHandle); |
| 179 | + return; |
| 180 | + } |
| 181 | + finishedBackendTaskHandle = task.backendTaskHandle; |
| 182 | + } |
| 183 | +} |
| 184 | + |
| 185 | +} // namespace UC::CacheStore |
0 commit comments