@@ -60,53 +60,7 @@ ggml_tensor * sparse_attn_indexer::idx_compute_scores_tile(
         t0_us = ggml_time_us();
     }
 
-    ggml_tensor * scores_acc = nullptr;
-    long HEAD_CHUNK = H;
-
-    if (const char *env = getenv("LLAMA_SPARSE_TOPK_HEAD_CHUNK")) {
-        long v = strtol(env, nullptr, 10);
-        if (v > 0) HEAD_CHUNK = v;
-    }
-    if (HEAD_CHUNK > (long)H) HEAD_CHUNK = (long)H;
-    if (HEAD_CHUNK < 1) HEAD_CHUNK = 1;
-
-    ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]);
-
-    for (int64_t h0 = 0; h0 < H; h0 += HEAD_CHUNK) {
-        int64_t ch = std::min<int64_t>(HEAD_CHUNK, H - h0);
-        size_t q_off_head = (size_t)t0 * q3d->nb[1] + (size_t)h0 * q3d->nb[2];
-        ggml_tensor * q_chunk_3d = ggml_view_3d(ctx, q3d, D, Tc, ch, q3d->nb[1], q3d->nb[2], q_off_head);
-        // permute [D, Tc, ch] -> [D, ch, Tc] so that head index is contiguous inside token-major layout
-        ggml_tensor * q_chunk_ht = ggml_permute(ctx, q_chunk_3d, 0, 2, 1, 3); // [D, ch, Tc]
-        q_chunk_ht = ggml_cont(ctx, q_chunk_ht);
-        ggml_tensor * q_chunk_2d = ggml_reshape_2d(ctx, q_chunk_ht, D, Tc*ch);
-        ggml_tensor * b_q = q_chunk_2d;
-        if (use_fp16 && q_chunk_2d->type != GGML_TYPE_F16) {
-            b_q = ggml_cast(ctx, q_chunk_2d, GGML_TYPE_F16);
-            b_q = ggml_cont(ctx, b_q);
-        }
-        ggml_tensor * k_slice = ggml_view_2d(ctx, a_k, D, kv_end, a_k->nb[1], 0);
-        ggml_tensor * logits_chunk = ggml_mul_mat(ctx, k_slice, b_q); // [kv_end, Tc*ch]
-        logits_chunk = ggml_cont(ctx, logits_chunk);
-        ggml_tensor * logits_chunk_3d = ggml_reshape_3d(ctx, logits_chunk, kv_end, ch, Tc); // [kv_end, ch, Tc]
-        logits_chunk_3d = ggml_relu(ctx, logits_chunk_3d);
-        size_t w_off_chunk = (size_t)h0 * w_slice->nb[0];
-        ggml_tensor * w_sub_2d = ggml_view_2d(ctx, w_slice, ch, Tc, w_slice->nb[1], w_off_chunk); // [ch, Tc]
-        w_sub_2d = ggml_cont(ctx, w_sub_2d);
-        ggml_tensor * w_sub_3d = ggml_reshape_3d(ctx, w_sub_2d, ch, 1, Tc); // [ch, 1, Tc]
-        ggml_tensor * log_p = ggml_permute(ctx, logits_chunk_3d, 1, 0, 2, 3); // [ch, N_kv, Tc]
-        log_p = ggml_cont(ctx, log_p);
-        ggml_tensor * w_bc = ggml_repeat(ctx, w_sub_3d, log_p); // [ch, N_kv, Tc]
-        w_bc = ggml_cont(ctx, w_bc);
-        ggml_tensor * prod = ggml_mul(ctx, log_p, w_bc); // [ch, N_kv, Tc]
-        ggml_tensor * sum_ch = ggml_sum_rows(ctx, prod); // [1, N_kv, Tc]
-        // sum_ch is [1, N_kv, Tc] with linear layout kv-major, then t; reshape directly to [kv_end, Tc]
-        ggml_tensor * scores_chunk = ggml_reshape_2d(ctx, sum_ch, kv_end, Tc); // [kv_end, Tc]
-        scores_acc = scores_acc ? ggml_add(ctx, scores_acc, scores_chunk) : scores_chunk;
-    }
-    ggml_tensor * scores_tc = scores_acc;
-
-    // CPU reference for TL FP8 lightning indexer, using GGML FP8 encode/decode.
+    // CPU FP8 Lightning Indexer reference, using GGML FP8 helpers.
     // Layout conventions:
     //   q3d : [D, T_total, H]
     //   a_k : [D, N_kv]
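
For review context, here is a minimal scalar sketch of the score that both the deleted GGML graph and the new CPU reference are meant to compute, per query token t in the tile and cached key j. The helper name idx_scores_scalar_ref is hypothetical; the sketch assumes the layouts noted in the comments above (q3d: [D, T_total, H], a_k: [D, N_kv], weights: [H, T_total]) stored as contiguous row-major ggml buffers:

    // score[t][j] = sum_h w[h][t0+t] * relu(dot(q3d[:, t0+t, h], a_k[:, j]))
    // (hypothetical standalone helper, not part of the PR)
    static void idx_scores_scalar_ref(const float * q, const float * k, const float * w,
                                      int64_t D, int64_t H, int64_t T_total,
                                      int64_t t0, int64_t Tc, int64_t kv_end,
                                      float * scores /* [kv_end, Tc] */) {
        for (int64_t t = 0; t < Tc; ++t) {
            for (int64_t j = 0; j < kv_end; ++j) {
                float acc = 0.0f;
                for (int64_t h = 0; h < H; ++h) {
                    float dot = 0.0f;
                    for (int64_t d = 0; d < D; ++d) {
                        dot += q[d + (t0 + t)*D + h*D*T_total] * k[d + j*D];
                    }
                    acc += w[h + (t0 + t)*H] * (dot > 0.0f ? dot : 0.0f); // matches ggml_relu + weight
                }
                scores[j + t*kv_end] = acc;
            }
        }
    }

The deleted graph computed this same quantity with the head loop blocked by LLAMA_SPARSE_TOPK_HEAD_CHUNK; the replacement moves the computation onto the CPU, with q/k passed through GGML's FP8 encode/decode per the removed comment.
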
@@ -212,7 +166,7 @@ ggml_tensor * sparse_attn_indexer::idx_compute_scores_tile(
     }
 
     // Materialize scores_tc as a new F32 tensor [kv_end, Tc]
-    scores_tc = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, Tc);
+    ggml_tensor * scores_tc = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, Tc);
     std::memcpy(scores_tc->data, out.data(), out.size() * sizeof(float));
     scores_tc->op = GGML_OP_NONE;
 
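
The materialization step above, shown in isolation as a hedged sketch. It assumes ctx allocates on the CPU backend, so tensor->data is directly writable host memory; materialize_f32_2d is a hypothetical name:

    #include <cstring>
    #include <vector>
    #include "ggml.h"

    // Copy a host-side float buffer into a fresh ggml F32 tensor of shape [ne0, ne1].
    static ggml_tensor * materialize_f32_2d(ggml_context * ctx,
                                            const std::vector<float> & src,
                                            int64_t ne0, int64_t ne1) {
        GGML_ASSERT((int64_t) src.size() == ne0*ne1);
        ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1);
        std::memcpy(t->data, src.data(), src.size()*sizeof(float));
        t->op = GGML_OP_NONE; // plain data leaf: nothing for the graph to execute
        return t;
    }

Usage mirroring the hunk above would be ggml_tensor * scores_tc = materialize_f32_2d(ctx, out, kv, Tc);.
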