
Commit 8981f5a

Flesh out get_k
1 parent 513ea61 commit 8981f5a

2 files changed (+208 -10)


src/llama-kv-cache-fp8.cpp

Lines changed: 205 additions & 10 deletions
@@ -276,6 +276,102 @@ static void e4m3_to_fp32_row(const ggml_e4m3_t * src, float * dst, int64_t k) {
     ggml_e4m3_to_fp32_row(src, dst, k);
 }
 
+
+const llama_kv_cache_fp8::kv_layer_fp8 * llama_kv_cache_fp8::get_layer(int32_t il) const {
+    auto it = map_layer_ids.find(il);
+    if (it == map_layer_ids.end()) {
+        return nullptr;
+    }
+    int32_t idx = it->second;
+    GGML_ASSERT(idx >= 0 && (size_t) idx < layers.size());
+    return &layers[idx];
+}
+
+static void pack_fp8_ds_mla_entry(
+        const float * latent,    // [512]
+        const float * rope,      // [64]
+        void        * dst_bytes) {
+    // Layout:
+    //   [0..511]   : 512 x FP8 E4M3 codes
+    //   [512..527] :   4 x FP32 scales
+    //   [528..655] :  64 x BF16 RoPE
+
+    uint8_t * dst = (uint8_t *) dst_bytes;
+
+    const int kv_lora_rank = 512;
+    const int rope_dim     = 64;
+    const int tile_size    = 128;
+    const int n_tiles      = kv_lora_rank / tile_size; // 4
+
+    float tile_scales[n_tiles];
+
+    // Compute per-tile scales
+    for (int t = 0; t < n_tiles; ++t) {
+        float amax = 0.0f;
+        const float * tile = latent + t * tile_size;
+        for (int i = 0; i < tile_size; ++i) {
+            float v = fabsf(tile[i]);
+            if (v > amax) amax = v;
+        }
+        // match vLLM: scale ~ amax / 448, guard against tiny amax
+        float scale = amax / 448.0f;
+        if (scale < 1e-4f) scale = 1e-4f;
+        tile_scales[t] = scale;
+    }
+
+    // Write scales after latent codes: view as float[4]
+    float * scale_dst = (float *)(dst + 512);
+    for (int t = 0; t < n_tiles; ++t) {
+        scale_dst[t] = tile_scales[t];
+    }
+
+    // Quantize latent to FP8 per tile
+    for (int t = 0; t < n_tiles; ++t) {
+        float inv_scale = 1.0f / tile_scales[t];
+        const float * tile = latent + t * tile_size;
+        ggml_e4m3_t * codes = (ggml_e4m3_t *)(dst + t * tile_size);
+        float tmp[tile_size];
+        for (int i = 0; i < tile_size; ++i) {
+            tmp[i] = tile[i] * inv_scale;
+        }
+        fp32_to_e4m3_row(tmp, codes, tile_size);
+    }
+
+    // Pack RoPE tail as BF16 at offset 528
+    ggml_bf16_t * rope_dst = (ggml_bf16_t *)(dst + 528);
+    ggml_fp32_to_bf16_row_ref(rope, rope_dst, rope_dim);
+}
+
+static void unpack_fp8_ds_mla_entry(
+        const void * src_bytes,
+        float       * latent_out,  // [512]
+        float       * rope_out) {  // [64]
+    const uint8_t * src = (const uint8_t *) src_bytes;
+
+    const int kv_lora_rank = 512;
+    const int rope_dim     = 64;
+    const int tile_size    = 128;
+    const int n_tiles      = kv_lora_rank / tile_size; // 4
+
+    const float * scale_src = (const float *)(src + 512);
+
+    // Dequantize latent
+    for (int t = 0; t < n_tiles; ++t) {
+        float scale = scale_src[t];
+        const ggml_e4m3_t * codes = (const ggml_e4m3_t *)(src + t * tile_size);
+        float tmp[tile_size];
+        e4m3_to_fp32_row(codes, tmp, tile_size);
+        float * tile_out = latent_out + t * tile_size;
+        for (int i = 0; i < tile_size; ++i) {
+            tile_out[i] = tmp[i] * scale;
+        }
+    }
+
+    // Unpack RoPE BF16 tail at offset 528
+    const ggml_bf16_t * rope_src = (const ggml_bf16_t *)(src + 528);
+    ggml_bf16_to_fp32_row(rope_src, rope_out, rope_dim);
+}
+
 // Clear / seq_* / state_* follow the patterns of llama_kv_cache but
 // operate on v_cells/v_heads only. For brevity we reuse the same
 // logic by delegating where possible.
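
For reference, the packed per-cell entry that pack_fp8_ds_mla_entry and unpack_fp8_ds_mla_entry agree on is 656 bytes: 512 one-byte E4M3 codes, four FP32 per-tile scales starting at byte 512, and 64 BF16 RoPE values starting at byte 528. Below is a minimal standalone sketch of that layout arithmetic; the struct and the 1-byte/2-byte stand-ins for ggml_e4m3_t/ggml_bf16_t are illustrative only, the cache itself works on a raw byte blob.

#include <cstddef>
#include <cstdint>

// Illustrative mirror of the 656-byte fp8_ds_mla entry layout.
struct fp8_ds_mla_entry {
    uint8_t  latent_codes[512]; // 512 x FP8 E4M3, one code per latent dim
    float    tile_scales[4];    // one FP32 scale per 128-wide tile
    uint16_t rope_bf16[64];     // 64 x BF16 RoPE tail
};

static_assert(offsetof(fp8_ds_mla_entry, tile_scales) == 512, "scales start at byte 512");
static_assert(offsetof(fp8_ds_mla_entry, rope_bf16)   == 528, "RoPE tail starts at byte 528");
static_assert(sizeof(fp8_ds_mla_entry)                == 656, "656 bytes per KV cell");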
@@ -672,12 +768,63 @@ bool llama_kv_cache_fp8::state_read_data(llama_io_read_i & io, uint32_t strm, ui
 // is not yet used in any graph. They will be filled in when wiring the
 // cache to DeepSeek V3.2.
 
+
+
 ggml_tensor * llama_kv_cache_fp8::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const llama_kv_cache::slot_info & sinfo) const {
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(il);
+    GGML_ASSERT(ctx != nullptr);
     GGML_UNUSED(n_kv);
-    GGML_UNUSED(sinfo);
-    return nullptr;
+
+    // Only support DeepSeek V3.2 fp8_ds_mla-style K blob for now.
+    if (model.arch != LLM_ARCH_DEEPSEEK3_2) {
+        return nullptr;
+    }
+
+    const kv_layer_fp8 * lyr = get_layer(il);
+    if (lyr == nullptr || lyr->k_blob == nullptr) {
+        return nullptr;
+    }
+
+    const uint32_t ns      = sinfo.s1 - sinfo.s0 + 1;
+    const uint32_t kv_size = get_size();
+
+    // We expose a simple [D=576, H=1, kv_size, ns] F32 layout for now:
+    //  - 512 dims: dequantized latent
+    //  -  64 dims: dequantized RoPE
+    const int64_t D_latent = 512;
+    const int64_t D_rope   = 64;
+    const int64_t D_total  = D_latent + D_rope;
+
+    ggml_tensor * out = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
+            D_total, 1, kv_size, ns);
+
+    // For each stream and KV index, unpack the 656-byte entry.
+    for (uint32_t s = 0; s < ns; ++s) {
+        uint32_t strm = sinfo.strm[s];
+        GGML_ASSERT(strm < lyr->k_blob->ne[2]);
+        for (uint32_t idx = 0; idx < kv_size; ++idx) {
+            // Compute byte offset into k_blob for (stream=strm, cell=idx)
+            size_t off = (size_t) idx * lyr->k_blob->nb[1] + (size_t) strm * lyr->k_blob->nb[2];
+            const uint8_t * src = (const uint8_t *) lyr->k_blob->data + off;
+
+            float latent[D_latent];
+            float rope[D_rope];
+            unpack_fp8_ds_mla_entry(src, latent, rope);
+
+            // Write into out: [D_total, 1, kv_size, ns]
+            for (int64_t d = 0; d < D_latent; ++d) {
+                ((float *) out->data)[d
+                    + D_total * (idx
+                    + (size_t) kv_size * s)] = latent[d];
+            }
+            for (int64_t d = 0; d < D_rope; ++d) {
+                ((float *) out->data)[(D_latent + d)
+                    + D_total * (idx
+                    + (size_t) kv_size * s)] = rope[d];
+            }
+        }
+    }
+
+    return out;
 }
 
 ggml_tensor * llama_kv_cache_fp8::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const llama_kv_cache::slot_info & sinfo) const {
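
The write loops above index the freshly allocated F32 tensor by hand. For a contiguous tensor with ne = {D_total, 1, kv_size, ns}, element (d, 0, idx, s) lives at float index d + D_total * (idx + kv_size * s), which is exactly what the loops compute. A small standalone check of that equivalence (the kv_size and ns values below are made up for the demonstration):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t D_total = 576, kv_size = 8, ns = 2; // toy sizes

    // Strides in floats for a contiguous tensor with ne = {D_total, 1, kv_size, ns}
    // (ggml's nb[] holds the same strides in bytes).
    const int64_t s1 = D_total, s2 = D_total * 1, s3 = D_total * 1 * kv_size;

    for (int64_t s = 0; s < ns; ++s)
    for (int64_t idx = 0; idx < kv_size; ++idx)
    for (int64_t d = 0; d < D_total; ++d) {
        const int64_t via_strides = d + 0 * s1 + idx * s2 + s * s3; // (d, 0, idx, s)
        const int64_t via_loop    = d + D_total * (idx + kv_size * s);
        assert(via_strides == via_loop);
    }
    printf("flat index math matches for all (d, idx, s)\n");
    return 0;
}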
@@ -688,13 +835,61 @@ ggml_tensor * llama_kv_cache_fp8::get_v(ggml_context * ctx, int32_t il, uint32_t
     return nullptr;
 }
 
+
+
 ggml_tensor * llama_kv_cache_fp8::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const llama_kv_cache::slot_info & sinfo) const {
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(k_cur);
-    GGML_UNUSED(k_idxs);
-    GGML_UNUSED(il);
-    GGML_UNUSED(sinfo);
-    return nullptr;
+    GGML_ASSERT(ctx != nullptr);
+    GGML_ASSERT(k_cur != nullptr);
+    GGML_ASSERT(k_idxs != nullptr);
+
+    // Only support DeepSeek V3.2 fp8_ds_mla-style K blob for now.
+    if (model.arch != LLM_ARCH_DEEPSEEK3_2) {
+        return nullptr;
+    }
+
+    const kv_layer_fp8 * lyr = get_layer(il);
+    if (lyr == nullptr || lyr->k_blob == nullptr) {
+        return nullptr;
+    }
+
+    const uint32_t ns      = sinfo.s1 - sinfo.s0 + 1;
+    const uint32_t kv_size = get_size();
+
+    // Expect k_cur layout [D_total, 1, n_tokens] where D_total = 512 + 64.
+    const int64_t D_total  = k_cur->ne[0];
+    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t D_latent = 512;
+    const int64_t D_rope   = 64;
+    GGML_ASSERT(D_total == D_latent + D_rope);
+
+    GGML_ASSERT(k_idxs->type == GGML_TYPE_I64);
+    const int64_t * idx_data = (const int64_t *) k_idxs->data;
+
+    for (int64_t t = 0; t < n_tokens; ++t) {
+        int64_t global_idx = idx_data[t];
+        GGML_ASSERT(global_idx >= 0 && global_idx < (int64_t) (kv_size * n_stream));
+        uint32_t strm = (uint32_t) (global_idx / kv_size);
+        uint32_t cell = (uint32_t) (global_idx % kv_size);
+
+        GGML_ASSERT(strm < lyr->k_blob->ne[2]);
+        size_t off = (size_t) cell * lyr->k_blob->nb[1] + (size_t) strm * lyr->k_blob->nb[2];
+        uint8_t * dst = (uint8_t *) lyr->k_blob->data + off;
+
+        float latent[D_latent];
+        float rope[D_rope];
+        for (int64_t d = 0; d < D_latent; ++d) {
+            latent[d] = ((float *) k_cur->data)[d + D_total * t];
+        }
+        for (int64_t d = 0; d < D_rope; ++d) {
+            rope[d] = ((float *) k_cur->data)[(D_latent + d) + D_total * t];
+        }
+
+        pack_fp8_ds_mla_entry(latent, rope, dst);
+    }
+
+    // cpy_k normally returns the ggml node representing the copy; here we
+    // simply return k_cur to keep the graph valid for now.
+    return k_cur;
 }
 
 ggml_tensor * llama_kv_cache_fp8::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const llama_kv_cache::slot_info & sinfo) const {
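
cpy_k treats each value in k_idxs as a flat index over all streams and cells: divide and modulo by kv_size split it into (stream, cell), and the blob's per-cell stride nb[1] and per-stream stride nb[2] turn the pair into a byte offset. A standalone sketch of that arithmetic with assumed sizes (656 packed bytes per cell, contiguous strides):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t kv_size  = 4096;          // cells per stream (assumed)
    const size_t   entry_sz = 656;           // packed bytes per cell
    const size_t   nb1      = entry_sz;      // stride between cells   (~k_blob->nb[1])
    const size_t   nb2      = nb1 * kv_size; // stride between streams (~k_blob->nb[2])

    const int64_t  global_idx = 4100;        // example value from k_idxs
    const uint32_t strm = (uint32_t) (global_idx / kv_size); // -> stream 1
    const uint32_t cell = (uint32_t) (global_idx % kv_size); // -> cell 4
    const size_t   off  = (size_t) cell * nb1 + (size_t) strm * nb2;

    printf("global %lld -> stream %u, cell %u, byte offset %zu\n",
           (long long) global_idx, strm, cell, off);
    return 0;
}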

src/llama-kv-cache-fp8.h

Lines changed: 3 additions & 0 deletions
@@ -157,4 +157,7 @@ class llama_kv_cache_fp8 : public llama_memory_i {
 
     bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
     bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+
+    const kv_layer_fp8 * get_layer(int32_t il) const;
+
 };
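
Circling back to the quantization rule in pack_fp8_ds_mla_entry above: each 128-wide tile is scaled by amax / 448 with a 1e-4 floor, so after dividing by the scale the tile's largest magnitude lands on E4M3's maximum finite value (448) and nothing overflows during encoding. A minimal standalone sketch of that rule on toy data:

#include <cmath>
#include <cstdio>

// Per-tile scale rule: amax / 448, floored at 1e-4 to guard against all-zero
// or near-zero tiles (mirrors the logic in pack_fp8_ds_mla_entry).
static float tile_scale(const float * tile, int n) {
    float amax = 0.0f;
    for (int i = 0; i < n; ++i) {
        amax = fmaxf(amax, fabsf(tile[i]));
    }
    const float scale = amax / 448.0f;
    return scale < 1e-4f ? 1e-4f : scale;
}

int main() {
    float tile[128];
    for (int i = 0; i < 128; ++i) {
        tile[i] = 0.01f * (i - 64); // toy data, amax = 0.64
    }
    const float s = tile_scale(tile, 128);
    printf("scale = %g, scaled amax = %g (fits E4M3's max of 448)\n", s, 0.64f / s);
    return 0;
}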
