Skip to content

Commit 589b957

Browse files
committed
bugfix in load queue while backend failed
1 parent 9160d95 commit 589b957

File tree

3 files changed

+130
-32
lines changed

3 files changed

+130
-32
lines changed

ucm/store/cache/cc/load_queue.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void LoadQueue::DispatchOneTask(TaskPtr task, WaiterPtr waiter)
102102
}
103103
shardTask.taskHandle = task->id;
104104
shardTask.shard = std::move(shard);
105-
shardTask.waiter = (i + 1 < nShard) ? nullptr : std::move(waiter);
105+
shardTask.waiter = (i + 1 < nShard) ? nullptr : waiter;
106106
}
107107
if (!backendTaskDesc.empty()) {
108108
auto res = backend_->Load(std::move(backendTaskDesc));

ucm/store/test/case/cache/cache_load_queue_test.cc

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,49 +24,27 @@
2424
#include <gtest/gtest.h>
2525
#include "cache/cc/load_queue.h"
2626
#include "detail/data_generator.h"
27+
#include "detail/mock_store.h"
2728
#include "detail/random.h"
2829
#include "detail/types_helper.h"
2930

3031
class UCCacheLoadQueueTest : public testing::Test {
31-
public:
32-
class StoreStub : public UC::Store {
33-
private:
34-
static UC::Detail::TaskHandle NextId()
35-
{
36-
static std::atomic<size_t> id{1};
37-
return id.fetch_add(1, std::memory_order_relaxed);
38-
}
39-
40-
public:
41-
UC::Expected<std::vector<uint8_t>> Lookup(const UC::Detail::BlockId* blocks,
42-
size_t num) override
43-
{
44-
std::vector<uint8_t> founds(num);
45-
for (size_t i = 0; i < num; i++) { founds[i] = true; }
46-
return founds;
47-
}
48-
void Prefetch(const UC::Detail::BlockId* blocks, size_t num) override {}
49-
UC::Expected<UC::Detail::TaskHandle> Load(UC::Detail::TaskDesc task) override
50-
{
51-
return NextId();
52-
}
53-
UC::Expected<UC::Detail::TaskHandle> Dump(UC::Detail::TaskDesc task) override
54-
{
55-
return NextId();
56-
}
57-
UC::Expected<bool> Check(UC::Detail::TaskHandle taskId) override { return true; }
58-
UC::Status Wait(UC::Detail::TaskHandle taskId) override { return UC::Status::OK(); }
59-
};
60-
6132
public:
6233
UC::Test::Detail::Random rd;
34+
static UC::Detail::TaskHandle NextId()
35+
{
36+
static std::atomic<size_t> id{1};
37+
return id.fetch_add(1, std::memory_order_relaxed);
38+
}
6339
};
6440

6541
TEST_F(UCCacheLoadQueueTest, LoadSameBlockTwice)
6642
{
6743
using namespace UC::CacheStore;
44+
UC::Test::Detail::MockStore backend;
45+
EXPECT_CALL(backend, Load).WillOnce(testing::Invoke(NextId));
46+
EXPECT_CALL(backend, Wait).WillOnce(testing::Return(UC::Status::OK()));
6847
UC::HashSet<UC::Detail::TaskHandle> failureSet;
69-
UCCacheLoadQueueTest::StoreStub backend;
7048
Config config;
7149
config.backend = (uintptr_t)(void*)&backend;
7250
config.tensorSize = 32768;
@@ -100,3 +78,76 @@ TEST_F(UCCacheLoadQueueTest, LoadSameBlockTwice)
10078
waiter2->Wait();
10179
ASSERT_FALSE(failureSet.Contains(task2->id));
10280
}
81+
82+
TEST_F(UCCacheLoadQueueTest, LoadWhileBackendSubmitFailed)
83+
{
84+
using namespace UC::CacheStore;
85+
using namespace testing;
86+
UC::Test::Detail::MockStore backend;
87+
EXPECT_CALL(backend, Load).WillOnce(testing::Return(UC::Status::Error()));
88+
UC::HashSet<UC::Detail::TaskHandle> failureSet;
89+
Config config;
90+
config.backend = (uintptr_t)(void*)&backend;
91+
config.tensorSize = 32768;
92+
config.shardSize = config.tensorSize;
93+
config.blockSize = config.shardSize;
94+
config.deviceId = 0;
95+
config.bufferSize = config.blockSize * 2048;
96+
config.engineId = rd.RandomString(10);
97+
config.shareBufferEnable = true;
98+
TransBuffer buffer;
99+
LoadQueue loadQ;
100+
auto s = buffer.Setup(config);
101+
ASSERT_EQ(s, UC::Status::OK());
102+
s = loadQ.Setup(config, &failureSet, &buffer);
103+
ASSERT_EQ(s, UC::Status::OK());
104+
auto blockId = UC::Test::Detail::TypesHelper::MakeBlockId("a1b2c3d4e5f6789012345678901234ab");
105+
constexpr size_t shardIdx = 0;
106+
UC::Test::Detail::DataGenerator data{1, config.blockSize};
107+
data.Generate();
108+
UC::Detail::TaskDesc desc{
109+
{blockId, shardIdx, {data.Buffer()}}
110+
};
111+
auto task = std::make_shared<TransTask>(TransTask::Type::LOAD, desc);
112+
auto waiter = std::make_shared<UC::Latch>();
113+
loadQ.Submit(task, waiter);
114+
waiter->Wait();
115+
ASSERT_TRUE(failureSet.Contains(task->id));
116+
}
117+
118+
TEST_F(UCCacheLoadQueueTest, LoadWhileBackendWaitFailed)
119+
{
120+
using namespace UC::CacheStore;
121+
using namespace testing;
122+
UC::Test::Detail::MockStore backend;
123+
EXPECT_CALL(backend, Load).WillOnce(testing::Invoke(NextId));
124+
EXPECT_CALL(backend, Wait).WillOnce(testing::Return(UC::Status::Error()));
125+
UC::HashSet<UC::Detail::TaskHandle> failureSet;
126+
Config config;
127+
config.backend = (uintptr_t)(void*)&backend;
128+
config.tensorSize = 32768;
129+
config.shardSize = config.tensorSize;
130+
config.blockSize = config.shardSize;
131+
config.deviceId = 0;
132+
config.bufferSize = config.blockSize * 2048;
133+
config.engineId = rd.RandomString(10);
134+
config.shareBufferEnable = true;
135+
TransBuffer buffer;
136+
LoadQueue loadQ;
137+
auto s = buffer.Setup(config);
138+
ASSERT_EQ(s, UC::Status::OK());
139+
s = loadQ.Setup(config, &failureSet, &buffer);
140+
ASSERT_EQ(s, UC::Status::OK());
141+
auto blockId = UC::Test::Detail::TypesHelper::MakeBlockId("a1b2c3d4e5f6789012345678901234ab");
142+
constexpr size_t shardIdx = 0;
143+
UC::Test::Detail::DataGenerator data{1, config.blockSize};
144+
data.Generate();
145+
UC::Detail::TaskDesc desc{
146+
{blockId, shardIdx, {data.Buffer()}}
147+
};
148+
auto task = std::make_shared<TransTask>(TransTask::Type::LOAD, desc);
149+
auto waiter = std::make_shared<UC::Latch>();
150+
loadQ.Submit(task, waiter);
151+
waiter->Wait();
152+
ASSERT_TRUE(failureSet.Contains(task->id));
153+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* MIT License
3+
*
4+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
* */
24+
#ifndef UNIFIEDCACHE_TEST_MOCK_STORE_H
25+
#define UNIFIEDCACHE_TEST_MOCK_STORE_H
26+
27+
#include <gmock/gmock.h>
28+
#include "ucmstore.h"
29+
30+
namespace UC::Test::Detail {
31+
32+
class MockStore : public UC::Store {
33+
public:
34+
MOCK_METHOD((UC::Expected<std::vector<uint8_t>>), Lookup,
35+
(const UC::Detail::BlockId* blocks, size_t num), (override));
36+
MOCK_METHOD(void, Prefetch, (const UC::Detail::BlockId* blocks, size_t num), (override));
37+
MOCK_METHOD((UC::Expected<UC::Detail::TaskHandle>), Load, (UC::Detail::TaskDesc task),
38+
(override));
39+
MOCK_METHOD((UC::Expected<UC::Detail::TaskHandle>), Dump, (UC::Detail::TaskDesc task),
40+
(override));
41+
MOCK_METHOD((UC::Expected<bool>), Check, (UC::Detail::TaskHandle taskId), (override));
42+
MOCK_METHOD((UC::Status), Wait, (UC::Detail::TaskHandle taskId), (override));
43+
};
44+
45+
} // namespace UC::Test::Detail
46+
47+
#endif

0 commit comments

Comments
 (0)