
Commit 8715063

I don't have a benchmark from before we added CPU FP8, but this change restores a fair bit (if not all) of the lost performance:

- The inner-most `d` loop does only `dot += qv[d] * kvp[d];`
- All FP8 work has been hoisted into the Qq/Kh precomputation loops, which are O(D * H * Tc + D * kv) instead of O(D * H * Tc * kv).
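For a rough sense of scale (illustrative dimensions, not measured from this change): with D = 128, H = 16, Tc = 64 and kv = 4096, the old inner loop performed two `f32_to_e4m3_to_f32` round-trips per multiply-accumulate, roughly 2 * D * H * Tc * kv ≈ 1.1e9 conversions per tile, whereas the hoisted precomputation needs only D * H * Tc + D * kv ≈ 6.6e5.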
1 parent d495b26 commit 8715063

File tree

1 file changed: +29, −6 lines


src/llama-sparse-indexer.cpp

Lines changed: 29 additions & 6 deletions
@@ -96,6 +96,18 @@ ggml_tensor * sparse_attn_indexer::idx_compute_scores_tile(
         }
     }
 
+    // Precompute FP8-dequantized Q: Qq = dequant(quant(Q))
+    std::vector<float> Qq(Q.size());
+    for (int64_t tc = 0; tc < Tc; ++tc) {
+        for (int64_t h = 0; h < H; ++h) {
+            for (int64_t d = 0; d < D; ++d) {
+                size_t idx_q = (size_t)d + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h);
+                Qq[idx_q] = f32_to_e4m3_to_f32(Q[idx_q]);
+            }
+        }
+    }
+
+
     // Pack weights [H, Tc] for this tile: W[h + H*tc]
     std::vector<float> W((size_t)H * (size_t)Tc);
     for (int64_t tc = 0; tc < Tc; ++tc) {
@@ -130,20 +142,31 @@
         K_sf[i] = maxv / 448.0f;
     }
 
-    // Compute FP8-like logits into host buffer
+    // Precompute FP8-dequantized K with per-row scaling: Kh = dequant(quant(K / K_sf[row]))
+    std::vector<float> Kh(K.size());
+    for (int64_t i = 0; i < kv; ++i) {
+        float sf = K_sf[i];
+        const float *kvp = K.data() + (size_t)D * (size_t)i;
+        float *khp = Kh.data() + (size_t)D * (size_t)i;
+        for (int64_t d = 0; d < D; ++d) {
+            float v = kvp[d] / sf;
+            khp[d] = f32_to_e4m3_to_f32(v);
+        }
+    }
+
+
+    // Compute FP8-like logits into host buffer using precomputed Qq and Kh
     std::vector<float> out((size_t)kv * (size_t)Tc, 0.0f);
     for (int64_t tc = 0; tc < Tc; ++tc) {
         for (int64_t i = 0; i < kv; ++i) {
             float acc = 0.0f;
-            const float *kvp = K.data() + (size_t)D * (size_t)i;
+            const float *kvp = Kh.data() + (size_t)D * (size_t)i;
             float sf_k = K_sf[i];
             for (int64_t h = 0; h < H; ++h) {
-                const float *qv = Q.data() + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h);
+                const float *qv = Qq.data() + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h);
                 float dot = 0.0f;
                 for (int64_t d = 0; d < D; ++d) {
-                    float qh = f32_to_e4m3_to_f32(qv[d]);
-                    float kh = f32_to_e4m3_to_f32(kvp[d] / sf_k);
-                    dot += qh * kh;
+                    dot += qv[d] * kvp[d];
                 }
                 if (dot < 0.0f) dot = 0.0f; // ReLU
                 acc += dot * W[(size_t)h + (size_t)H * (size_t)tc];

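As a standalone illustration of the same pattern, here is a minimal, self-contained sketch (not the repository's actual code): `round_trip_e4m3` is a crude stand-in for `f32_to_e4m3_to_f32`, and the dimensions, fill values, and output write are made up for the example. It only demonstrates hoisting the FP8 round-trip out of the innermost loop.

```cpp
// Sketch: hoist the FP8 (e4m3) round-trip out of the innermost dot-product
// loop into O(D*H*Tc) and O(D*kv) precomputation passes, leaving a plain
// multiply-accumulate in the O(D*H*Tc*kv) hot loop.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Crude stand-in for f32_to_e4m3_to_f32: clamp to the e4m3 range and keep
// only a few mantissa bits. Placeholder only -- the real helper handles
// rounding and denormals properly.
static float round_trip_e4m3(float v) {
    if (v >  448.0f) v =  448.0f;
    if (v < -448.0f) v = -448.0f;
    if (v == 0.0f) return 0.0f;
    int e = 0;
    float m = std::frexp(v, &e);           // v = m * 2^e, |m| in [0.5, 1)
    m = std::round(m * 16.0f) / 16.0f;     // keep ~4 mantissa bits
    return std::ldexp(m, e);
}

int main() {
    // Made-up dimensions, just for the sketch.
    const int64_t D = 8, H = 2, Tc = 3, kv = 4;

    std::vector<float> Q((size_t)(D * H * Tc), 0.50f); // queries [D, H, Tc]
    std::vector<float> K((size_t)(D * kv),     0.25f); // keys    [D, kv]
    std::vector<float> K_sf((size_t)kv, 1.0f);         // per-row K scale factors
    std::vector<float> W((size_t)(H * Tc), 1.0f);      // per-(head, tc) weights

    // Hoisted pass 1: FP8 round-trip of Q, O(D*H*Tc).
    std::vector<float> Qq(Q.size());
    for (size_t i = 0; i < Q.size(); ++i) {
        Qq[i] = round_trip_e4m3(Q[i]);
    }

    // Hoisted pass 2: FP8 round-trip of K with per-row scaling, O(D*kv).
    std::vector<float> Kh(K.size());
    for (int64_t i = 0; i < kv; ++i) {
        for (int64_t d = 0; d < D; ++d) {
            Kh[(size_t)(D * i + d)] = round_trip_e4m3(K[(size_t)(D * i + d)] / K_sf[i]);
        }
    }

    // Hot loops: the innermost d loop is now a plain multiply-accumulate.
    std::vector<float> out((size_t)(kv * Tc), 0.0f);
    for (int64_t tc = 0; tc < Tc; ++tc) {
        for (int64_t i = 0; i < kv; ++i) {
            float acc = 0.0f;
            const float * kvp = Kh.data() + (size_t)(D * i);
            for (int64_t h = 0; h < H; ++h) {
                const float * qv = Qq.data() + (size_t)(D * (tc * H + h));
                float dot = 0.0f;
                for (int64_t d = 0; d < D; ++d) {
                    dot += qv[d] * kvp[d];
                }
                if (dot < 0.0f) dot = 0.0f; // ReLU
                acc += dot * W[(size_t)(h + H * tc)];
            }
            out[(size_t)(i + kv * tc)] = acc;
        }
    }

    std::printf("out[0] = %f\n", out[0]);
    return 0;
}
```

Note that the real kernel still reads `float sf_k = K_sf[i];` inside the kv loop, presumably for use below the hunk shown in this diff; the sketch omits that to keep the loop structure in focus.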