
Commit 78da439

FP8 K is inferring again.
1 parent 3f17d34 · commit 78da439

File tree: 3 files changed, +29 -0 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
@@ -2359,6 +2359,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_KV_DSMLA_PACK:
+            {
+                // trivial metadata op for FP8 KV; handled only on CUDA backend
+                n_tasks = 1;
+            } break;
         case GGML_OP_NONE:
             {
                 n_tasks = 1;
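Returning n_tasks = 1 only tells the CPU scheduler to reserve a single thread for this node; if a graph containing GGML_OP_KV_DSMLA_PACK were ever scheduled to the CPU backend, the op would still need a forward handler. A minimal sketch of such a stub, assuming the CPU path treats the op as a no-op (hypothetical: the real kernel lives in the CUDA backend, and this commit does not show a CPU dispatch site):

    // hypothetical CPU stub: the op only packs FP8 KV metadata and is expected
    // to be executed by the CUDA backend, so there is nothing to compute here
    static void ggml_compute_forward_kv_dsmla_pack(
            const struct ggml_compute_params * params,
            struct ggml_tensor * dst) {
        GGML_UNUSED(params);
        GGML_UNUSED(dst);
        // single task (n_tasks = 1); no data is read or written on the CPU
    }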

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 0 deletions
@@ -3803,6 +3803,19 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (op->ne[2] != 1 || op->ne[3] != 1) return false;
                 return ggml_is_contiguous(a);
             } break;
+        case GGML_OP_KV_DSMLA_PACK:
+            {
+                const struct ggml_tensor * k_lr  = op->src[0];
+                const struct ggml_tensor * k_idx = op->src[1];
+                const struct ggml_tensor * blob  = op->src[2];
+                if (!k_lr || !k_idx || !blob) return false;
+                if (k_lr->type  != GGML_TYPE_F32) return false;
+                if (k_idx->type != GGML_TYPE_I64) return false;
+                if (blob->type  != GGML_TYPE_I8)  return false;
+                if (!ggml_is_contiguous(k_lr)) return false;
+                if (k_lr->ne[1] != 1) return false;
+                return true;
+            } break;
         case GGML_OP_GLU:
             switch (ggml_get_glu_op(op)) {
                 case GGML_GLU_OP_REGLU:
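For context, the branch above only accepts the op when its three sources have the expected types and layout. The sketch below builds tensors that satisfy every check, purely to illustrate the constraints; the sizes are arbitrary, the "presumably" comments are interpretations of the names, and the ggml_kv_dsmla_pack(...) builder mentioned at the end is assumed to exist elsewhere in this patch set and is not shown in this commit:

    // illustrative only: shapes/sizes are arbitrary, chosen to pass supports_op
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // src[0]: F32, contiguous, ne[1] == 1 (presumably the low-rank K for one token)
    struct ggml_tensor * k_lr  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 1);
    // src[1]: I64 (presumably destination cell indices in the KV cache)
    struct ggml_tensor * k_idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
    // src[2]: I8 (presumably the raw FP8 K blob backing store)
    struct ggml_tensor * blob  = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, 1024*1024);

    // e.g. ggml_kv_dsmla_pack(ctx, k_lr, k_idx, blob);  // builder assumed, not in this commit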

src/llama-kv-cache-fp8.cpp

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "llama-kv-cache-fp8.h"
+#include "ggml-backend.h"

 #include "llama-impl.h"
 #include "llama-io.h"
@@ -813,6 +814,11 @@ ggml_tensor * llama_kv_cache_fp8::get_k(ggml_context * ctx, int32_t il, uint32_t
     if (lyr == nullptr || lyr->k_blob == nullptr) {
         return nullptr;
     }
+    // If K blob is device-resident, we cannot safely dereference it on host during graph build.
+    // In that case, skip FP8-derived K and let callers fall back to the float KV cache.
+    if (lyr->k_blob->buffer && !ggml_backend_buffer_is_host(lyr->k_blob->buffer)) {
+        return nullptr;
+    }

     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
     const uint32_t kv_size = get_size();
@@ -881,6 +887,11 @@ ggml_tensor * llama_kv_cache_fp8::cpy_k(ggml_context * ctx, ggml_tensor * k_cur,
     if (lyr == nullptr || lyr->k_blob == nullptr) {
         return nullptr;
     }
+    // If K blob is device-resident, we cannot safely dereference it on host during graph build.
+    // In that case, skip FP8-derived K and let callers fall back to the float KV cache.
+    if (lyr->k_blob->buffer && !ggml_backend_buffer_is_host(lyr->k_blob->buffer)) {
+        return nullptr;
+    }

     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
     const uint32_t kv_size = get_size();
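The guard added to get_k() and cpy_k() is the same check, so it could be factored into a small helper if more call sites appear. A minimal sketch using only the public ggml-backend API already included above (the helper name is illustrative, not part of this commit):

    #include "ggml-backend.h"

    // true if t->data may be dereferenced on the host: either the tensor has
    // no backend buffer yet, or its buffer is host-visible
    static bool tensor_is_host_readable(const struct ggml_tensor * t) {
        return t->buffer == NULL || ggml_backend_buffer_is_host(t->buffer);
    }

    // usage mirroring the checks above:
    //     if (!tensor_is_host_readable(lyr->k_blob)) { return nullptr; }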
