Skip to content

Commit a523479

Browse files
committed
Add a test for the fp8 kv cache
1 parent 8981f5a commit a523479

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama)
227227
llama_build_and_test(test-alloc.cpp)
228228
target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
229229

230+
# FP8 KV DSMLA roundtrip test for DeepSeek V3.2
231+
llama_build_and_test(test-fp8-kv-dsmla.cpp)
230232

231233
# Unit test for sparse_attn_indexer::idx_compute_scores_tile
232234
llama_build_and_test(test-indexer-scores-tile.cpp)

tests/test-fp8-kv-dsmla.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#include "../src/llama-kv-cache-fp8.h"
2+
#include "../src/llama-model.h"
3+
#include "../src/llama-impl.h"
4+
5+
#include <ggml-alloc.h>
6+
#include <ggml-cpp.h>
7+
#include <ggml.h>
8+
9+
#include <cassert>
10+
#include <cmath>
11+
#include <cstdio>
12+
#include <vector>
13+
14+
// Simple unit test that exercises the DeepSeek V3.2 FP8 KV K blob
15+
// layout (fp8_ds_mla-style 656-byte entries) by round-tripping
16+
// synthetic latent + RoPE data through llama_kv_cache_fp8::cpy_k
17+
// and llama_kv_cache_fp8::get_k.
18+
19+
static void test_fp8_kv_dsmla_roundtrip() {
20+
printf("[fp8-kv-dsmla] starting roundtrip test...\n");
21+
fflush(stdout);
22+
23+
// Minimal hparams: 1 layer with KV, DeepSeek3.2 arch, kv_lora_rank=512, rope_dim=64
24+
llama_model_params mparams = llama_model_default_params();
25+
llama_model * model = new llama_model(mparams);
26+
model->arch = LLM_ARCH_DEEPSEEK3_2;
27+
28+
llama_hparams & hp = model->hparams;
29+
hp.n_layer = 1;
30+
hp.n_layer_kv_from_start = 1; // has_kv(0) == true
31+
hp.n_lora_kv = 512; // kv_lora_rank
32+
hp.n_rot = 64; // rope_dim
33+
hp.n_embd = 576; // not used here directly
34+
35+
// Ensure layers vector has at least 1 entry
36+
model->layers.resize(1);
37+
38+
const uint32_t kv_size = 4; // a few KV cells per stream
39+
const uint32_t n_seq_max = 1; // single stream
40+
const uint32_t n_pad = 1;
41+
const uint32_t n_swa = 0;
42+
43+
// Construct an FP8 KV cache instance
44+
llama_kv_cache_fp8 kv_fp8(
45+
*model,
46+
GGML_TYPE_F16, // ignored for DeepSeek V3.2 path
47+
GGML_TYPE_F16, // ignored for DeepSeek V3.2 path
48+
/*v_trans*/ true,
49+
/*offload*/ false,
50+
/*unified*/ true,
51+
kv_size,
52+
n_seq_max,
53+
n_pad,
54+
n_swa,
55+
LLAMA_SWA_TYPE_NONE,
56+
/*filter*/ nullptr,
57+
/*reuse*/ nullptr);
58+
59+
// Synthetic latent+RoPE per token
60+
const int64_t D_latent = 512;
61+
const int64_t D_rope = 64;
62+
const int64_t D_total = D_latent + D_rope;
63+
64+
const int64_t n_tokens = 3; // write 3 tokens into first 3 KV cells
65+
66+
// Build a ggml context for tensors
67+
ggml_init_params params = {};
68+
params.mem_size = 16 * 1024 * 1024;
69+
params.mem_buffer = nullptr;
70+
params.no_alloc = false;
71+
ggml_context * ctx = ggml_init(params);
72+
GGML_ASSERT(ctx != nullptr);
73+
74+
// k_cur: [D_total, 1, n_tokens]
75+
ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D_total, 1, n_tokens);
76+
77+
// Fill with a simple deterministic pattern
78+
float * k_data = (float *) k_cur->data;
79+
for (int64_t t = 0; t < n_tokens; ++t) {
80+
for (int64_t d = 0; d < D_total; ++d) {
81+
float base = 0.01f * float(t + 1);
82+
// keep magnitudes reasonable for fp8 quantization
83+
k_data[d + D_total * t] = base * (1.0f + 0.001f * float(d));
84+
}
85+
}
86+
87+
// k_idxs: global KV indices for each token, here 0,1,2
88+
ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
89+
int64_t * idx_data = (int64_t *) k_idxs->data;
90+
for (int64_t t = 0; t < n_tokens; ++t) {
91+
idx_data[t] = t; // stream=0, cell=t
92+
}
93+
94+
// Build a minimal slot_info that maps a single stream 0
95+
llama_kv_cache::slot_info sinfo;
96+
sinfo.s0 = 0;
97+
sinfo.s1 = 0;
98+
sinfo.strm = { 0 };
99+
sinfo.idxs = { std::vector<uint32_t>(kv_size) };
100+
for (uint32_t i = 0; i < kv_size; ++i) {
101+
sinfo.idxs[0][i] = i;
102+
}
103+
104+
// Write into the FP8 K blob using the new DS-MLA cpy_k
105+
kv_fp8.cpy_k(ctx, k_cur, k_idxs, /*il=*/0, sinfo);
106+
107+
// Read back using get_k: expect [576, 1, kv_size, ns=1]
108+
ggml_tensor * k_out = kv_fp8.get_k(ctx, /*il=*/0, kv_size, sinfo);
109+
GGML_ASSERT(k_out != nullptr);
110+
GGML_ASSERT(k_out->type == GGML_TYPE_F32);
111+
GGML_ASSERT(k_out->ne[0] == D_total);
112+
GGML_ASSERT(k_out->ne[1] == 1);
113+
GGML_ASSERT(k_out->ne[2] == kv_size);
114+
GGML_ASSERT(k_out->ne[3] == 1);
115+
116+
const float * out_data = (const float *) k_out->data;
117+
118+
// Compare only the first n_tokens cells; the rest are unspecified
119+
float max_abs_err = 0.0f;
120+
for (int64_t t = 0; t < n_tokens; ++t) {
121+
for (int64_t d = 0; d < D_total; ++d) {
122+
float orig = k_data[d + D_total * t];
123+
float got = out_data[d + D_total * (t + kv_size * 0)];
124+
float err = fabsf(orig - got);
125+
if (err > max_abs_err) max_abs_err = err;
126+
}
127+
}
128+
129+
printf("[fp8-kv-dsmla] max_abs_err = %g\n", (double) max_abs_err);
130+
fflush(stdout);
131+
132+
// FP8 + BF16 round-trip is lossy; allow a modest tolerance
133+
GGML_ASSERT(max_abs_err < 0.1f);
134+
135+
ggml_free(ctx);
136+
delete model;
137+
138+
printf("[fp8-kv-dsmla] roundtrip test PASSED\n");
139+
fflush(stdout);
140+
}
141+
142+
int main() {
143+
test_fp8_kv_dsmla_roundtrip();
144+
return 0;
145+
}

0 commit comments

Comments
 (0)