Skip to content

Commit 00930e6

Browse files
committed
Fix wkv test & add gla test
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent aaa870e commit 00930e6

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

tests/test-backend-ops.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,17 +1659,46 @@ struct test_rwkv_wkv6 : public test_case {
16591659

16601660
ggml_tensor * build_graph(ggml_context * ctx) override {
16611661
const int64_t n_tokens = n_seq_tokens * n_seqs;
1662-
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1663-
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
1664-
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1662+
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1663+
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1664+
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
16651665
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
1666-
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1666+
ggml_tensor * td = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
16671667
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
16681668
ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
16691669
return out;
16701670
}
16711671
};
16721672

1673+
// GGML_OP_GATED_LINEAR_ATTN
1674+
struct test_gla : public test_case {
1675+
const ggml_type type;
1676+
1677+
const int64_t head_count;
1678+
const int64_t head_size;
1679+
const int64_t n_seq_tokens;
1680+
const int64_t n_seqs;
1681+
1682+
std::string vars() override {
1683+
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
1684+
}
1685+
1686+
test_gla(ggml_type type = GGML_TYPE_F32,
1687+
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
1688+
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1689+
1690+
ggml_tensor * build_graph(ggml_context * ctx) override {
1691+
const int64_t n_tokens = n_seq_tokens * n_seqs;
1692+
ggml_tensor * q = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1693+
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1694+
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1695+
ggml_tensor * g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1696+
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
1697+
ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
1698+
return out;
1699+
}
1700+
};
1701+
16731702
// GGML_OP_MUL_MAT
16741703
struct test_mul_mat : public test_case {
16751704
const ggml_type type_a;
@@ -3626,6 +3655,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
36263655
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
36273656
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
36283657

3658+
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
3659+
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
3660+
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
3661+
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
3662+
36293663
for (int i = 1; i < 9; ++i) {
36303664
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
36313665
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));

0 commit comments

Comments
 (0)