
Commit e952efa

Add the FP8 pack custom op hook and replace the unsafe pointer-based wiring.
1 parent 5112464 commit e952efa

File tree: 3 files changed, +64 / -38 lines

src/llama-kv-cache-fp8.cpp

Lines changed: 54 additions & 0 deletions
@@ -769,6 +769,60 @@ bool llama_kv_cache_fp8::state_read_data(llama_io_read_i & io, uint32_t strm, ui
     return false;
 }
 
+
+
+struct kv_dsmla_pack_userdata {
+    int32_t il;
+    int32_t kv_size;
+    int32_t n_stream;
+};
+
+static void kv_dsmla_pack_custom(ggml_tensor * dst, int ith, int nth, void * userdata) {
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ith);
+    GGML_UNUSED(nth);
+    GGML_UNUSED(userdata);
+    // CPU stub: real work is performed in CUDA backend via specialized handler.
+}
+
+ggml_tensor * llama_kv_cache_fp8::build_k_pack_node(
+        ggml_context * ctx,
+        ggml_tensor * k_latent_rope,
+        ggml_tensor * k_idxs,
+        int32_t il) const {
+    GGML_ASSERT(ctx != nullptr);
+    GGML_ASSERT(k_latent_rope != nullptr);
+    GGML_ASSERT(k_idxs != nullptr);
+    GGML_ASSERT(k_idxs->type == GGML_TYPE_I64);
+
+    if (model.arch != LLM_ARCH_DEEPSEEK3_2) {
+        return k_latent_rope;
+    }
+
+    const kv_layer_fp8 * lyr = get_layer(il);
+    if (lyr == nullptr || lyr->k_blob == nullptr) {
+        return k_latent_rope;
+    }
+
+    kv_dsmla_pack_userdata * ud = new kv_dsmla_pack_userdata;
+    ud->il = il;
+    ud->kv_size = (int32_t) get_size();
+    ud->n_stream = (int32_t) get_n_stream();
+
+    ggml_tensor * args[3] = { k_latent_rope, k_idxs, lyr->k_blob };
+    ggml_tensor * node = ggml_custom_4d(
+        ctx,
+        GGML_TYPE_F32,
+        1, 1, 1, 1,
+        args,
+        3,
+        kv_dsmla_pack_custom,
+        GGML_N_TASKS_MAX,
+        ud);
+
+    return node;
+}
+
 // Accessors for K/V are left unimplemented for now since the FP8 cache
 // is not yet used in any graph. They will be filled in when wiring the
 // cache to DeepSeek V3.2.
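
The CPU callback above is intentionally a stub. As a rough illustration only: ggml_custom_4d() records the args array in dst->src[] and hands userdata back to the callback, so a non-stub handler (on the CPU or in a backend) could recover its inputs as sketched below. The example function name, the thread partitioning, and the FP8 write are placeholders; the actual k_blob packing layout is not defined by this commit.

// Illustrative sketch only: how a non-stub handler could unpack its inputs.
// Reuses the kv_dsmla_pack_userdata struct added above; the FP8 write itself
// is left as a placeholder because the blob layout is not part of this commit.
#include <algorithm>   // std::min
#include <cstdint>
#include "ggml.h"

static void kv_dsmla_pack_example(ggml_tensor * dst, int ith, int nth, void * userdata) {
    const auto * ud = static_cast<const kv_dsmla_pack_userdata *>(userdata);

    // ggml_custom_4d() stores the args array in dst->src[], in order
    const ggml_tensor * k_latent_rope = dst->src[0]; // [D_latent + D_rope, 1, n_tokens], F32
    const ggml_tensor * k_idxs        = dst->src[1]; // [n_tokens], I64 destination slots
    ggml_tensor       * k_blob        = dst->src[2]; // per-layer FP8 blob (layer ud->il)

    const int64_t n_tokens = k_latent_rope->ne[2];

    // split the tokens across nth threads; this thread handles [t0, t1)
    const int64_t per_thread = (n_tokens + nth - 1) / nth;
    const int64_t t0 = ith * per_thread;
    const int64_t t1 = std::min(n_tokens, t0 + per_thread);

    for (int64_t t = t0; t < t1; ++t) {
        const int64_t slot = ((const int64_t *) k_idxs->data)[t]; // assumes host-visible data
        GGML_ASSERT(slot >= 0 && slot < ud->kv_size);

        // ... quantize the row of k_latent_rope for token t to FP8 and write it
        //     at `slot` inside k_blob (scale/layout depend on the cache design) ...
        GGML_UNUSED(k_blob);
    }
}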

src/llama-kv-cache-fp8.h

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ class llama_kv_cache_fp8 : public llama_memory_i {
     void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    ggml_tensor * build_k_pack_node(ggml_context * ctx, ggml_tensor * k_latent_rope, ggml_tensor * k_idxs, int32_t il) const;
+
 private:
     const llama_model & model;
     const llama_hparams & hparams;

src/llama-model.cpp

Lines changed: 8 additions & 38 deletions
@@ -13904,27 +13904,12 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il));
         ggml_build_forward_expand(gf, inp_attn->get_kq_mask());
 
-        // Optional: DeepSeek V3.2 FP8 K-side KV cache write
+        // Optional: DeepSeek V3.2 FP8 K-side KV cache pack (custom op)
         if (model.kv_fp8_ds32) {
-            const int64_t D_latent = kv_lora_rank;
-            const int64_t D_rope = n_rot;
-            const int64_t D_total = D_latent + D_rope;
-            GGML_UNUSED(D_total);
             ggml_tensor * k_fp8_in = ggml_concat(ctx0, kv_cmpr, k_pe, 0); // [D_total,1,n_tokens]
-
-            llama_kv_cache::slot_info sinfo_fp8;
-            sinfo_fp8.s0 = 0;
-            sinfo_fp8.s1 = 0;
-            sinfo_fp8.strm = { 0 };
-            sinfo_fp8.idxs = { std::vector<uint32_t>(model.kv_fp8_ds32->get_size()) };
-            for (uint32_t i = 0; i < model.kv_fp8_ds32->get_size(); ++i) {
-                sinfo_fp8.idxs[0][i] = i;
-            }
-
-            ggml_tensor * k_idxs = inp_attn->get_k_idxs();
-            ggml_build_forward_expand(
-                gf,
-                model.kv_fp8_ds32->cpy_k(ctx0, k_fp8_in, k_idxs, il, sinfo_fp8));
+            ggml_tensor * k_idxs = inp_attn->get_k_idxs();
+            ggml_tensor * pack_node = model.kv_fp8_ds32->build_k_pack_node(ctx0, k_fp8_in, k_idxs, il);
+            ggml_build_forward_expand(gf, pack_node);
         }
     }

@@ -14060,27 +14045,12 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il));
         ggml_build_forward_expand(gf, inp_attn->get_kq_mask());
 
-        // Optional: DeepSeek V3.2 FP8 K-side KV cache write
+        // Optional: DeepSeek V3.2 FP8 K-side KV cache pack (custom op)
         if (model.kv_fp8_ds32) {
-            const int64_t D_latent = kv_lora_rank;
-            const int64_t D_rope = n_rot;
-            const int64_t D_total = D_latent + D_rope;
-            GGML_UNUSED(D_total);
             ggml_tensor * k_fp8_in = ggml_concat(ctx0, kv_cmpr, k_pe, 0); // [D_total,1,n_tokens]
-
-            llama_kv_cache::slot_info sinfo_fp8;
-            sinfo_fp8.s0 = 0;
-            sinfo_fp8.s1 = 0;
-            sinfo_fp8.strm = { 0 };
-            sinfo_fp8.idxs = { std::vector<uint32_t>(model.kv_fp8_ds32->get_size()) };
-            for (uint32_t i = 0; i < model.kv_fp8_ds32->get_size(); ++i) {
-                sinfo_fp8.idxs[0][i] = i;
-            }
-
-            ggml_tensor * k_idxs = inp_attn->get_k_idxs();
-            ggml_build_forward_expand(
-                gf,
-                model.kv_fp8_ds32->cpy_k(ctx0, k_fp8_in, k_idxs, il, sinfo_fp8));
+            ggml_tensor * k_idxs = inp_attn->get_k_idxs();
+            ggml_tensor * pack_node = model.kv_fp8_ds32->build_k_pack_node(ctx0, k_fp8_in, k_idxs, il);
+            ggml_build_forward_expand(gf, pack_node);
         }
     }
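
For reference, the hook pattern used in this wiring (a one-element dummy output whose only purpose is to pull a side-effecting callback into the graph) can be exercised in isolation roughly as follows. This is a sketch under assumptions: the tensor shapes, the pack_userdata struct, and the printf body are made up for illustration, and ggml-cpu.h is assumed to provide ggml_graph_compute_with_ctx in current ggml builds.

// Standalone sketch of the same ggml_custom_4d() hook pattern (ggml.h / ggml-cpu.h
// assumed; tensor shapes and names here are illustrative only).
#include <cstdio>
#include "ggml.h"
#include "ggml-cpu.h"   // ggml_graph_compute_with_ctx in recent ggml builds

struct pack_userdata {
    int32_t il; // layer index, mirroring the commit's kv_dsmla_pack_userdata
};

static void pack_custom(ggml_tensor * dst, int ith, int nth, void * userdata) {
    // the args passed to ggml_custom_4d() arrive as dst->src[0..n_args-1]
    const ggml_tensor * src = dst->src[0];
    const auto * ud = static_cast<const pack_userdata *>(userdata);
    if (ith == 0) {
        std::printf("pack layer %d: %lld input elements, %d threads\n",
                    ud->il, (long long) ggml_nelements(src), nth);
    }
}

int main() {
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    // stand-in for k_fp8_in: [D_total, 1, n_tokens]
    ggml_tensor * k_in = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 576, 1, 8);

    pack_userdata ud = { /*.il =*/ 0 };
    ggml_tensor * args[1] = { k_in };

    // one-element dummy output: the node exists only for its side effect
    ggml_tensor * pack_node = ggml_custom_4d(ctx, GGML_TYPE_F32, 1, 1, 1, 1,
                                             args, 1, pack_custom, GGML_N_TASKS_MAX, &ud);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, pack_node); // expanding the node is what schedules the pack
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
    return 0;
}

Note that the commit heap-allocates its userdata with new so it remains valid after graph construction; stack storage is enough in this toy example only because the userdata is not needed past ggml_graph_compute_with_ctx().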
