@@ -1659,17 +1659,46 @@ struct test_rwkv_wkv6 : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t n_tokens = n_seq_tokens * n_seqs;
-        ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
-        ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
-        ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
-        ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * td = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
         ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
         return out;
     }
 };
 
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_gla(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * q = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -3626,6 +3655,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
     for (int i = 1; i < 9; ++i) {
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16,  GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
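
For context on what the new test exercises: test_gla builds q/k/v/g as [head_size, head_count, n_tokens] tensors plus a head_size*head_size state per head and sequence, and passes scale = pow(head_size, -0.5). The sketch below is a minimal, illustrative reading of the gated-linear-attention recurrence such an op computes, written against plain arrays rather than ggml tensors. It is not taken from this commit: the row-major layout, the per-key-dimension broadcast of the gate g, and updating the state before reading the output are assumptions; the ggml CPU kernel for GGML_OP_GATED_LINEAR_ATTN is the authoritative definition.

// Reference sketch (assumptions noted above) of one head of gated linear
// attention over a token sequence. For each token t:
//   S = diag(g_t) * S + k_t * v_t^T      (decay state, inject outer product)
//   o_t = scale * S^T * q_t              (read out against the query)
#include <cstddef>
#include <vector>

static void gla_head_ref(std::vector<float> & S,   // head_size*head_size state, S[i*head_size + j]
                         const float * q,          // n_tokens x head_size
                         const float * k,          // n_tokens x head_size
                         const float * v,          // n_tokens x head_size
                         const float * g,          // n_tokens x head_size, gate in [0, 1]
                         float       * out,        // n_tokens x head_size
                         size_t head_size, size_t n_tokens, float scale) {
    for (size_t t = 0; t < n_tokens; ++t) {
        const float * qt = q   + t*head_size;
        const float * kt = k   + t*head_size;
        const float * vt = v   + t*head_size;
        const float * gt = g   + t*head_size;
        float       * ot = out + t*head_size;
        // decay the state by the gate and add the current outer product k_t v_t^T
        for (size_t i = 0; i < head_size; ++i) {
            for (size_t j = 0; j < head_size; ++j) {
                S[i*head_size + j] = gt[i]*S[i*head_size + j] + kt[i]*vt[j];
            }
        }
        // read out: o_t = scale * S^T q_t
        for (size_t j = 0; j < head_size; ++j) {
            float acc = 0.0f;
            for (size_t i = 0; i < head_size; ++i) {
                acc += qt[i]*S[i*head_size + j];
            }
            ot[j] = scale*acc;
        }
    }
}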