From de93b991682fa46acabb813641f7f6b884ad5dd0 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Fri, 28 Nov 2025 19:31:22 +0000 Subject: [PATCH] Add PagedAttention support (experimental, CUDA only) Implement PagedAttention algorithm from for memory-efficient KV cache management. This feature reduces memory fragmentation by storing KV cache in fixed-size blocks (similar to virtual memory paging) and enables efficient memory sharing between sequences through copy-on-write semantics. The implementation is experimental and disabled by default. Enable with the --pagedattention flag Signed-off-by: Eric Curtin --- common/arg.cpp | 8 + common/common.cpp | 4 + common/common.h | 1 + ggml/include/ggml.h | 28 + ggml/src/ggml-cpu/ggml-cpu.c | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 12 + ggml/src/ggml-cuda/paged-attention-backend.cu | 223 ++++ .../src/ggml-cuda/paged-attention-backend.cuh | 5 + ggml/src/ggml-cuda/paged-attention-v1.cu | 348 ++++++ ggml/src/ggml-cuda/paged-attention-v2.cu | 527 ++++++++ ggml/src/ggml-cuda/paged-attention.cuh | 256 ++++ ggml/src/ggml-cuda/paged-cpy.cu | 151 +++ ggml/src/ggml-cuda/paged-cpy.cuh | 6 + ggml/src/ggml.c | 70 +- include/llama.h | 1 + src/CMakeLists.txt | 1 + src/llama-context.cpp | 25 +- src/llama-cparams.h | 1 + src/llama-graph.cpp | 212 +++- src/llama-graph.h | 53 + src/llama-impl.h | 1 + src/llama-kv-cache-paged.cpp | 1061 +++++++++++++++++ src/llama-kv-cache-paged.h | 245 ++++ src/llama-model.cpp | 16 + src/models/llama.cpp | 16 +- tools/server/server-context.cpp | 7 +- 26 files changed, 3265 insertions(+), 17 deletions(-) create mode 100644 ggml/src/ggml-cuda/paged-attention-backend.cu create mode 100644 ggml/src/ggml-cuda/paged-attention-backend.cuh create mode 100644 ggml/src/ggml-cuda/paged-attention-v1.cu create mode 100644 ggml/src/ggml-cuda/paged-attention-v2.cu create mode 100644 ggml/src/ggml-cuda/paged-attention.cuh create mode 100644 ggml/src/ggml-cuda/paged-cpy.cu create mode 100644 ggml/src/ggml-cuda/paged-cpy.cuh create mode 100644 src/llama-kv-cache-paged.cpp create mode 100644 src/llama-kv-cache-paged.h diff --git a/common/arg.cpp b/common/arg.cpp index 9f3c8a97546..a31c6b1c925 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1017,6 +1017,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str())); } }).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"--pagedattention"}, + "enable PagedAttention for KV cache (experimental, requires CUDA)", + [](common_params & params) { + fprintf(stderr, "DEBUG: --pagedattention flag parsed, setting params.use_paged_attention = true\n"); + params.use_paged_attention = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index 0d7fd9a9371..bf97b96cff3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1249,6 +1249,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { struct llama_context_params common_context_params_to_llama(const common_params & params) { auto cparams = llama_context_default_params(); + fprintf(stderr, "DEBUG common_context_params_to_llama: params.use_paged_attention = %s\n", + params.use_paged_attention ? 
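Note: the --pagedattention flag above maps one-to-one onto the new use_paged_attention field of llama_context_params, so the feature can also be enabled programmatically. A minimal sketch, assuming the current llama.h loading entry points (llama_model_load_from_file, llama_init_from_model); only the use_paged_attention field itself comes from this patch:

#include "llama.h"

int main(int argc, char ** argv) {
    (void) argc;

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.use_paged_attention = true; // experimental, CUDA only; the context
                                        // disables Flash Attention by itself

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... run inference as usual ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}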
"true" : "false"); + cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; cparams.n_batch = params.n_batch; @@ -1275,6 +1278,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.op_offload = !params.no_op_offload; cparams.swa_full = params.swa_full; cparams.kv_unified = params.kv_unified; + cparams.use_paged_attention = params.use_paged_attention; cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; diff --git a/common/common.h b/common/common.h index 2f23d0baa83..055f9e61fd5 100644 --- a/common/common.h +++ b/common/common.h @@ -406,6 +406,7 @@ struct common_params { bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache + bool use_paged_attention = false; // enable PagedAttention (experimental, requires CUDA) bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool use_mmap = true; // use mmap for faster loads diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc7..1b46f0cc2d2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -537,6 +537,8 @@ extern "C" { GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, + GGML_OP_PAGED_ATTENTION, + GGML_OP_PAGED_CPY, GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, GGML_OP_WIN_PART, @@ -2312,6 +2314,32 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * sinks); + // PagedAttention (paged KV cache attention) + // q: [n_tokens, n_heads, head_size] + // k_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged) + // v_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged) + // block_tables: [n_seqs, max_blocks_per_seq] (int32) + // seq_lens: [n_seqs] (int32) + GGML_API struct ggml_tensor * ggml_paged_attention( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k_cache, + struct ggml_tensor * v_cache, + struct ggml_tensor * block_tables, + struct ggml_tensor * seq_lens, + int32_t block_size, + float scale); + + // Copy K/V data to paged cache blocks (similar to vLLM's reshape_and_cache) + // kv_cur: [head_size, n_heads, n_tokens] - K or V current + // kv_cache: [num_blocks, n_kv_heads, head_size, block_size] - paged K or V cache + // slot_idxs: [n_tokens] (int32) - cache slot index for each token + GGML_API struct ggml_tensor * ggml_paged_cpy( + struct ggml_context * ctx, + struct ggml_tensor * kv_cur, + struct ggml_tensor * kv_cache, + struct ggml_tensor * slot_idxs); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3247af8bb03..70f323497ae 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2062,6 +2062,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { // nop } break; + case GGML_OP_PAGED_ATTENTION: + { + // nop (CUDA-only operation) + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fa7e1e13a71..2f756407e81 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -32,6 +32,8 @@ #include "ggml-cuda/opt-step-sgd.cuh" #include "ggml-cuda/out-prod.cuh" #include "ggml-cuda/pad.cuh" +#include "ggml-cuda/paged-attention-backend.cuh" +#include 
"ggml-cuda/paged-cpy.cuh" #include "ggml-cuda/pool2d.cuh" #include "ggml-cuda/quantize.cuh" #include "ggml-cuda/rope.cuh" @@ -2719,6 +2721,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_PAGED_ATTENTION: + ggml_cuda_op_paged_attention(ctx, dst); + break; + case GGML_OP_PAGED_CPY: + ggml_cuda_op_paged_cpy(ctx, dst); + break; case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; @@ -4564,6 +4572,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + case GGML_OP_PAGED_ATTENTION: + return ggml_cuda_can_paged_attention(op); + case GGML_OP_PAGED_CPY: + return true; case GGML_OP_SOLVE_TRI: return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; default: diff --git a/ggml/src/ggml-cuda/paged-attention-backend.cu b/ggml/src/ggml-cuda/paged-attention-backend.cu new file mode 100644 index 00000000000..71e4c93d77b --- /dev/null +++ b/ggml/src/ggml-cuda/paged-attention-backend.cu @@ -0,0 +1,223 @@ +/** + * GGML CUDA Backend for PagedAttention + * + * This file provides the CUDA backend implementation for the GGML_OP_PAGED_ATTENTION operation. + * It bridges GGML's operation framework with the PagedAttention CUDA kernels. + * + * NOTE: PagedAttention is currently experimental and only supported on CUDA. + * MUSA support is disabled due to compiler compatibility issues. + */ + +// PagedAttention is not yet supported on MUSA +#ifndef GGML_USE_MUSA + +#include "common.cuh" +#include "paged-attention.cuh" +#include "paged-attention-backend.cuh" + +// Extract parameters from GGML tensor +static void ggml_cuda_op_paged_attention_get_params( + const ggml_tensor * dst, + float * scale, + int32_t * block_size) { + + const float * params = (const float *)dst->op_params; + *scale = params[0]; + *block_size = (int32_t)params[1]; +} + +// Main CUDA backend function for PagedAttention +void ggml_cuda_op_paged_attention( + ggml_backend_cuda_context & ctx, + ggml_tensor * dst) { + + const ggml_tensor * q = dst->src[0]; // query + const ggml_tensor * k_cache = dst->src[1]; // key cache (paged) + const ggml_tensor * v_cache = dst->src[2]; // value cache (paged) + const ggml_tensor * block_tables = dst->src[3]; // block tables + const ggml_tensor * seq_lens = dst->src[4]; // sequence lengths + const ggml_tensor * alibi_slopes = dst->src[5]; // optional ALiBi slopes (can be nullptr) + + // Extract parameters + float scale; + int32_t block_size; + ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size); + + // Get tensor dimensions + const int64_t head_size = q->ne[0]; + const int64_t n_heads = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t n_kv_heads = k_cache->ne[2]; + const int64_t num_blocks = k_cache->ne[0]; + + const int64_t max_blocks_per_seq = block_tables->ne[0]; + + // Validate tensor dimensions + GGML_ASSERT(n_tokens > 0 && "Number of query tokens must be positive"); + GGML_ASSERT(n_seqs > 0 && "Number of sequences must be positive"); + GGML_ASSERT(num_blocks > 0 && "Number of KV cache blocks must be positive"); + GGML_ASSERT(max_blocks_per_seq > 0 && "Max blocks per sequence must be positive"); + + // Validate that we have enough blocks available + // Note: This is a soft check - actual usage depends on sequence lengths + GGML_ASSERT(num_blocks >= max_blocks_per_seq && + "Total number of blocks should be >= max 
blocks per sequence"); + + // For PagedAttention, typically we have one query per sequence (decode mode) + // or multiple queries per sequence (prefill mode) + GGML_ASSERT(n_tokens <= n_seqs * 1024 && + "Number of tokens seems unusually large relative to batch size"); + + // Get pointers + void * out_ptr = dst->data; + const void * q_ptr = q->data; + const void * k_cache_ptr = k_cache->data; + const void * v_cache_ptr = v_cache->data; + const int32_t * block_tables_ptr = (const int32_t *)block_tables->data; + const int32_t * seq_lens_ptr = (const int32_t *)seq_lens->data; + + // Debug: Check for null pointers + GGML_ASSERT(out_ptr != nullptr && "Output pointer is null"); + GGML_ASSERT(q_ptr != nullptr && "Query pointer is null"); + GGML_ASSERT(k_cache_ptr != nullptr && "K cache pointer is null"); + GGML_ASSERT(v_cache_ptr != nullptr && "V cache pointer is null"); + GGML_ASSERT(block_tables_ptr != nullptr && "Block tables pointer is null"); + GGML_ASSERT(seq_lens_ptr != nullptr && "Sequence lengths pointer is null"); + + // Get ALiBi slopes pointer if provided + const float * alibi_slopes_ptr = nullptr; + if (alibi_slopes != nullptr) { + // ALiBi slopes should be a 1D tensor with one slope per attention head + GGML_ASSERT(alibi_slopes->type == GGML_TYPE_F32 && + "ALiBi slopes must be float32"); + GGML_ASSERT(alibi_slopes->ne[0] == n_heads && + "ALiBi slopes tensor must have one value per head"); + alibi_slopes_ptr = (const float *)alibi_slopes->data; + } + + // Calculate max sequence length (needed to decide V1 vs V2) + int max_seq_len = 0; + for (int i = 0; i < n_seqs; i++) { + if (seq_lens_ptr[i] > max_seq_len) { + max_seq_len = seq_lens_ptr[i]; + } + } + + // Get CUDA stream + cudaStream_t stream = ctx.stream(); + + // Decide whether to use V1 or V2 + const bool use_v1 = ggml_cuda_paged_attention::should_use_v1( + max_seq_len, n_seqs, n_heads); + + // Launch appropriate kernel + if (use_v1) { + ggml_cuda_paged_attention::paged_attention_v1_launcher( + out_ptr, + q_ptr, + k_cache_ptr, + v_cache_ptr, + n_seqs, + n_heads, + n_kv_heads, + head_size, + block_size, + max_blocks_per_seq, + block_tables_ptr, + seq_lens_ptr, + max_seq_len, + scale, + alibi_slopes_ptr, + q->type, + k_cache->type, + stream); + } else { + ggml_cuda_paged_attention::paged_attention_v2_launcher( + out_ptr, + q_ptr, + k_cache_ptr, + v_cache_ptr, + n_seqs, + n_heads, + n_kv_heads, + head_size, + block_size, + max_blocks_per_seq, + block_tables_ptr, + seq_lens_ptr, + max_seq_len, + scale, + alibi_slopes_ptr, + q->type, + k_cache->type, + ctx.pool(), + stream); + } + + // Check for errors + CUDA_CHECK(cudaGetLastError()); +} + +// Check if PagedAttention is supported for given configuration +bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k_cache = dst->src[1]; + + // Check data types + if (q->type != GGML_TYPE_F16 && q->type != GGML_TYPE_F32) { + return false; + } + + if (k_cache->type != GGML_TYPE_F16 && k_cache->type != GGML_TYPE_F32) { + return false; + } + + // Check head size is supported + const int64_t head_size = q->ne[0]; + const int supported_head_sizes[] = {32, 64, 80, 96, 112, 120, 128, 192, 256}; + bool head_size_supported = false; + + for (int hs : supported_head_sizes) { + if (head_size == hs) { + head_size_supported = true; + break; + } + } + + if (!head_size_supported) { + return false; + } + + // Extract block size and check it's supported + float scale; + int32_t block_size; + ggml_cuda_op_paged_attention_get_params(dst, &scale, 
&block_size); + + if (block_size != 8 && block_size != 16 && block_size != 32) { + return false; + } + + return true; +} + +#else // GGML_USE_MUSA + +// Stub implementations for MUSA (PagedAttention not yet supported) +#include "common.cuh" + +void ggml_cuda_op_paged_attention( + ggml_backend_cuda_context & ctx, + ggml_tensor * dst) { + GGML_UNUSED(ctx); + GGML_UNUSED(dst); + GGML_ABORT("PagedAttention is not yet supported on MUSA"); +} + +bool ggml_cuda_supports_paged_attention(const ggml_tensor * dst) { + GGML_UNUSED(dst); + return false; +} + +#endif // GGML_USE_MUSA diff --git a/ggml/src/ggml-cuda/paged-attention-backend.cuh b/ggml/src/ggml-cuda/paged-attention-backend.cuh new file mode 100644 index 00000000000..8a04ee8c0d2 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-attention-backend.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void ggml_cuda_op_paged_attention(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_can_paged_attention(const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/paged-attention-v1.cu b/ggml/src/ggml-cuda/paged-attention-v1.cu new file mode 100644 index 00000000000..eeb96f0dde9 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-attention-v1.cu @@ -0,0 +1,348 @@ +// PagedAttention is not yet supported on MUSA +#ifndef GGML_USE_MUSA + +#include "paged-attention.cuh" +#include "common.cuh" + +/** + * PagedAttention CUDA Kernel Implementation + * + * Based on the PagedAttention implementation from vLLM: + * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cuh + * + * Copyright (c) 2023-2024 vLLM Project + * SPDX-License-Identifier: Apache-2.0 + * + * Adapted for GGML by llama.cpp contributors + */ + +namespace ggml_cuda_paged_attention { + +// +// Main PagedAttention V1 Kernel +// +// This kernel computes attention for one sequence and one head per thread block. +// It reads K/V from paged blocks based on the block table. +// + +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ q, + const cache_t* __restrict__ k_cache, + const cache_t* __restrict__ v_cache, + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, + const int* __restrict__ seq_lens, + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int thread_idx = threadIdx.x; + + const int seq_len = seq_lens[seq_idx]; + if (seq_len == 0) return; + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // Shared memory for logits and reduction + extern __shared__ char shared_mem[]; + float* logits = reinterpret_cast(shared_mem); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + __shared__ float red_smem[2 * NUM_WARPS]; + + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + // Get KV head index (for GQA/MQA) + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + + // ALiBi bias (if applicable) + const float alibi_slope = alibi_slopes ? 
alibi_slopes[head_idx] : 0.0f; + + // Query pointer for this sequence and head + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + // Step 2: Compute Q·K for all tokens + float qk_max = -FLT_MAX; + + for (int block_idx = warp_idx; block_idx < num_seq_blocks; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + + // Load K vectors from this block and compute dot products + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int token_idx = block_idx * BLOCK_SIZE + i; + if (token_idx >= seq_len) break; + + // Compute Q·K dot product + // K cache layout: [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* k_ptr = k_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + i; // token position within block + + // Compute dot product between Q and K + // Each thread computes part of the dot product + float qk = 0.0f; + for (int elem_idx = thread_idx; elem_idx < HEAD_SIZE; elem_idx += NUM_THREADS) { + // Load K element for this token + // K is stored as [head_size, block_size], so offset is elem_idx * BLOCK_SIZE + const cache_t k_val = k_ptr[elem_idx * BLOCK_SIZE]; + + // Load Q element (from scalar_t array) + const scalar_t q_val = q_ptr[elem_idx]; + + // Accumulate dot product + qk += float(q_val) * float(k_val); + } + + // Reduce across all threads in the block + #pragma unroll + for (int mask = NUM_THREADS / 2; mask >= 1; mask /= 2) { + qk += SHFL_XOR_SYNC(qk, mask); + } + + // Apply scale + qk *= scale; + + // Add ALiBi bias if applicable + if (alibi_slope != 0.0f) { + qk += alibi_slope * (token_idx - seq_len + 1); + } + + // Store logit (only thread 0 writes after full reduction) + if (thread_idx == 0) { + logits[token_idx] = qk; + } + + qk_max = fmaxf(qk_max, qk); + } + } + + // Step 3: Warp-level reduction to find max logit + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // Block-level reduction + qk_max = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; + #pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, SHFL_XOR_SYNC(qk_max, mask)); + } + qk_max = SHFL_SYNC(qk_max, 0); + + // Step 4: Compute softmax + float exp_sum = 0.0f; + for (int i = thread_idx; i < seq_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Normalize + const float inv_sum = __fdividef(1.0f, exp_sum + 1e-6f); + for (int i = thread_idx; i < seq_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Step 5: Compute attention output (softmax · V) + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + float accs[NUM_ROWS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.0f; + } + + // Compute attention output by multiplying softmax weights with V + // V cache layout: [num_blocks, num_kv_heads, head_size, block_size] + for (int block_idx = warp_idx; block_idx < num_seq_blocks; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int token_idx = block_idx * BLOCK_SIZE + i; + if (token_idx >= seq_len) break; + + // Get attention weight for this token + const float attn_weight = logits[token_idx]; + + // Accumulate V vectors weighted by attention + #pragma unroll + for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + // V cache pointer for this token and head dimension + const cache_t* v_ptr = v_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + row_idx * BLOCK_SIZE + i; + + accs[j] += attn_weight * float(*v_ptr); + } + } + } + } + + // Step 6: Warp-level reduction of attention output + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; + #pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + __syncthreads(); + + // Step 7: Block-level reduction and write output + float* out_smem = reinterpret_cast(shared_mem); + + // Each warp writes its partial results to shared memory + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + out_smem[warp_idx * HEAD_SIZE + row_idx] = accs[i]; + } + } + __syncthreads(); + + // Final reduction across warps and write output + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + + for (int i = thread_idx; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + #pragma unroll + for (int w = 0; w < NUM_WARPS; w++) { + acc += out_smem[w * HEAD_SIZE + i]; + } + from_float(out_ptr[i], acc); + } +} + +// +// Launcher function +// +// Handles type dispatch and kernel launch configuration +// + +void paged_attention_v1_launcher( + void* out, + const void* query, + const void* key_cache, + const void* value_cache, + int num_seqs, + int num_heads, + int num_kv_heads, + int head_size, + int block_size, + int max_num_blocks_per_seq, + const 
int* block_tables, + const int* seq_lens, + int max_seq_len, + float scale, + const float* alibi_slopes, + ggml_type q_type, + ggml_type kv_cache_type, + cudaStream_t stream) { + + // Determine thread block configuration + constexpr int NUM_THREADS = 128; + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + + // Calculate shared memory size + const int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, block_size) * block_size; + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_THREADS / WARP_SIZE / 2) * head_size * sizeof(float); + const int shared_mem_size = max(logits_size, outputs_size); + + // Compute strides + const int q_stride = num_heads * head_size; + const int kv_block_stride = num_kv_heads * head_size * block_size; + const int kv_head_stride = head_size * block_size; + + // Macro to dispatch kernel based on head size and block size + #define LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_kernel \ + <<>>( \ + (SCALAR_T*)out, (const SCALAR_T*)query, \ + (const CACHE_T*)key_cache, (const CACHE_T*)value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, \ + max_num_blocks_per_seq, alibi_slopes, \ + q_stride, kv_block_stride, kv_head_stride) + + // Dispatch for head size + #define DISPATCH_HEAD_SIZE(SCALAR_T, CACHE_T, BLOCK_SIZE) \ + switch (head_size) { \ + case 32: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 32, BLOCK_SIZE); break; \ + case 64: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 64, BLOCK_SIZE); break; \ + case 80: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 80, BLOCK_SIZE); break; \ + case 96: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 96, BLOCK_SIZE); break; \ + case 112: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 112, BLOCK_SIZE); break; \ + case 120: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 120, BLOCK_SIZE); break; \ + case 128: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 128, BLOCK_SIZE); break; \ + case 192: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 192, BLOCK_SIZE); break; \ + case 256: LAUNCH_PAGED_ATTENTION_V1(SCALAR_T, CACHE_T, 256, BLOCK_SIZE); break; \ + default: \ + fprintf(stderr, "Unsupported head size: %d\n", head_size); \ + GGML_ABORT("fatal error"); \ + } + + // Dispatch for block size + #define DISPATCH_BLOCK_SIZE(SCALAR_T, CACHE_T) \ + switch (block_size) { \ + case 8: DISPATCH_HEAD_SIZE(SCALAR_T, CACHE_T, 8); break; \ + case 16: DISPATCH_HEAD_SIZE(SCALAR_T, CACHE_T, 16); break; \ + case 32: DISPATCH_HEAD_SIZE(SCALAR_T, CACHE_T, 32); break; \ + default: \ + fprintf(stderr, "Unsupported block size: %d\n", block_size); \ + GGML_ABORT("fatal error"); \ + } + + // Type dispatch based on q_type and kv_cache_type + if (q_type == GGML_TYPE_F16 && kv_cache_type == GGML_TYPE_F16) { + DISPATCH_BLOCK_SIZE(half, half); + } else if (q_type == GGML_TYPE_F32 && kv_cache_type == GGML_TYPE_F32) { + DISPATCH_BLOCK_SIZE(float, float); + } else if (q_type == GGML_TYPE_F16 && kv_cache_type == GGML_TYPE_F32) { + DISPATCH_BLOCK_SIZE(half, float); + } else if (q_type == GGML_TYPE_F32 && kv_cache_type == GGML_TYPE_F16) { + DISPATCH_BLOCK_SIZE(float, half); + } else { + fprintf(stderr, "Unsupported data type combination: q_type=%d, kv_cache_type=%d\n", + q_type, kv_cache_type); + GGML_ABORT("fatal error"); + } + + #undef LAUNCH_PAGED_ATTENTION_V1 + #undef DISPATCH_HEAD_SIZE + #undef DISPATCH_BLOCK_SIZE + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace ggml_cuda_paged_attention + +#endif // GGML_USE_MUSA diff --git 
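Note: the value the V1 kernel computes for one (sequence, head) pair is plain scaled-dot-product attention, with K and V gathered through the block table. A simplified host-side sketch of that semantics, f32 only and without ALiBi, using the same [num_blocks, n_kv_heads, head_size, block_size] cache layout and strides as the launcher above (illustrative only, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference computation of softmax(scale * q.K) . V for one sequence and one
// KV head, with K/V fetched through the block table.
static std::vector<float> paged_attention_ref(
        const float   * q,            // [head_size]
        const float   * k_cache,      // [num_blocks][n_kv_heads][head_size][block_size]
        const float   * v_cache,      // same layout as k_cache
        const int32_t * block_table,  // [max_blocks_per_seq] row for this sequence
        int seq_len, int head_size, int block_size,
        int n_kv_heads, int kv_head_idx, float scale) {
    const size_t kv_block_stride = (size_t) n_kv_heads * head_size * block_size;
    const size_t kv_head_stride  = (size_t) head_size * block_size;

    // 1. logits[t] = scale * q.k_t for every cached token of the sequence
    std::vector<float> logits(seq_len);
    float max_logit = -INFINITY;
    for (int t = 0; t < seq_len; ++t) {
        const int64_t blk = block_table[t / block_size];
        const int     off = t % block_size;
        const float * k = k_cache + blk * kv_block_stride + kv_head_idx * kv_head_stride + off;
        float qk = 0.0f;
        for (int e = 0; e < head_size; ++e) {
            qk += q[e] * k[e * block_size]; // elements of one token are block_size apart
        }
        logits[t] = qk * scale;
        max_logit = std::max(max_logit, logits[t]);
    }

    // 2. numerically stable softmax
    float sum = 0.0f;
    for (int t = 0; t < seq_len; ++t) {
        logits[t] = std::exp(logits[t] - max_logit);
        sum += logits[t];
    }

    // 3. output = softmax-weighted sum of V
    std::vector<float> out(head_size, 0.0f);
    for (int t = 0; t < seq_len; ++t) {
        const int64_t blk = block_table[t / block_size];
        const int     off = t % block_size;
        const float * v = v_cache + blk * kv_block_stride + kv_head_idx * kv_head_stride + off;
        const float w = logits[t] / sum;
        for (int e = 0; e < head_size; ++e) {
            out[e] += w * v[e * block_size];
        }
    }
    return out;
}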
a/ggml/src/ggml-cuda/paged-attention-v2.cu b/ggml/src/ggml-cuda/paged-attention-v2.cu new file mode 100644 index 00000000000..684d2766874 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-attention-v2.cu @@ -0,0 +1,527 @@ +// PagedAttention is not yet supported on MUSA +#ifndef GGML_USE_MUSA + +#include "paged-attention.cuh" +#include "common.cuh" + +/** + * PagedAttention V2 CUDA Kernel Implementation + * + * Based on the PagedAttention implementation from vLLM: + * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cuh + * + * Copyright (c) 2023-2024 vLLM Project + * SPDX-License-Identifier: Apache-2.0 + * + * Adapted for GGML by llama.cpp contributors + */ + +namespace ggml_cuda_paged_attention { + +// +// Main PagedAttention V2 Kernel +// +// This kernel computes partial attention for one partition. +// The main difference from V1 is that it processes only a subset of the sequence +// and stores intermediate results (max_logits, exp_sums, partial outputs). +// + +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, + float* __restrict__ max_logits, + scalar_t* __restrict__ tmp_out, + const scalar_t* __restrict__ q, + const cache_t* __restrict__ k_cache, + const cache_t* __restrict__ v_cache, + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, + const int* __restrict__ seq_lens, + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int max_num_partitions = gridDim.z; + const int thread_idx = threadIdx.x; + + const int seq_len = seq_lens[seq_idx]; + if (partition_idx * PARTITION_SIZE >= seq_len) { + // This partition is beyond the sequence length + return; + } + + // Calculate range of blocks to process in this partition + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = PARTITION_SIZE / BLOCK_SIZE; + const int start_block_idx = partition_idx * num_blocks_per_partition; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // Shared memory for partial logits and reduction + extern __shared__ char shared_mem[]; + float* logits = reinterpret_cast(shared_mem); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + __shared__ float red_smem[2 * NUM_WARPS]; + + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + // Get KV head index (for GQA/MQA) + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + + // ALiBi bias (if applicable) + const float alibi_slope = alibi_slopes ? 
alibi_slopes[head_idx] : 0.0f; + + // Query pointer for this sequence and head + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + // Compute Q·K for tokens in this partition only + float qk_max = -FLT_MAX; + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + + // Load K vectors and compute dot products + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int token_idx = block_idx * BLOCK_SIZE + i; + if (token_idx >= end_token_idx) break; + + // Compute Q·K dot product + // K cache layout: [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* k_ptr = k_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + i; // token position within block + + // Compute dot product between Q and K + // Each thread computes part of the dot product + float qk = 0.0f; + for (int elem_idx = thread_idx; elem_idx < HEAD_SIZE; elem_idx += NUM_THREADS) { + // Load K element for this token + // K is stored as [head_size, block_size], so offset is elem_idx * BLOCK_SIZE + const cache_t k_val = k_ptr[elem_idx * BLOCK_SIZE]; + + // Load Q element (from scalar_t array) + const scalar_t q_val = q_ptr[elem_idx]; + + // Accumulate dot product + qk += float(q_val) * float(k_val); + } + + // Reduce across all threads in the block + #pragma unroll + for (int mask = NUM_THREADS / 2; mask >= 1; mask /= 2) { + qk += SHFL_XOR_SYNC(qk, mask); + } + + // Apply scale + qk *= scale; + + // Add ALiBi bias if applicable + if (alibi_slope != 0.0f) { + qk += alibi_slope * (token_idx - seq_len + 1); + } + + // Store logit (only thread 0 writes after full reduction) + if (thread_idx == 0) { + logits[token_idx - start_token_idx] = qk; + } + + qk_max = fmaxf(qk_max, qk); + } + } + + // Warp and block level reduction to find max (same as V1) + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + qk_max = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; + #pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, SHFL_XOR_SYNC(qk_max, mask)); + } + qk_max = SHFL_SYNC(qk_max, 0); + + // Compute softmax (for this partition only) + float exp_sum = 0.0f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Store max_logit and exp_sum for this partition (for reduce kernel) + if (thread_idx == 0) { + const int idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + max_logits[idx] = qk_max; + exp_sums[idx] = exp_sum; + } + + // Don't normalize yet - will be done in reduce kernel + + // Compute partial attention output (softmax · V) + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + float accs[NUM_ROWS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.0f; + } + + // Compute partial attention output by multiplying softmax weights with V + // V cache layout: [num_blocks, num_kv_heads, head_size, block_size] + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table[block_idx]); + + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int token_idx = block_idx * BLOCK_SIZE + i; + if (token_idx >= end_token_idx) break; + + // Get attention weight for this token + const float attn_weight = logits[token_idx - start_token_idx]; + + // Accumulate V vectors weighted by attention + #pragma unroll + for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + // V cache pointer for this token and head dimension + const cache_t* v_ptr = v_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + row_idx * BLOCK_SIZE + i; + + accs[j] += attn_weight * float(*v_ptr); + } + } + } + } + + // Warp-level reduction of attention output + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; + #pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + __syncthreads(); + + // Block-level reduction and output write (to temporary buffer) + float* out_smem = reinterpret_cast(shared_mem); + + // Each warp writes its partial results to shared memory + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + out_smem[warp_idx * HEAD_SIZE + row_idx] = accs[i]; + } + } + __syncthreads(); + + // Final reduction across warps and write to temporary output + scalar_t* tmp_out_ptr = tmp_out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + + for (int i = thread_idx; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + #pragma unroll + for (int w = 0; w < NUM_WARPS; w++) { + acc += out_smem[w * HEAD_SIZE + i]; + } + from_float(tmp_out_ptr[i], acc); + } +} + +// +// PagedAttention 
V2 Reduce Kernel +// +// Combines partial results from all partitions +// + +template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, + const float* __restrict__ exp_sums, + const float* __restrict__ max_logits, + const scalar_t* __restrict__ tmp_out, + const int* __restrict__ seq_lens, + const int max_num_partitions) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int thread_idx = threadIdx.x; + + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + + // Find global max logit across all partitions + float global_max_logit = -FLT_MAX; + for (int i = 0; i < num_partitions; ++i) { + const int idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + i; + global_max_logit = fmaxf(global_max_logit, max_logits[idx]); + } + + // Share global max across threads + __shared__ float shared_global_max; + if (thread_idx == 0) { + shared_global_max = global_max_logit; + } + __syncthreads(); + global_max_logit = shared_global_max; + + // Compute rescaled exp_sum + float global_exp_sum = 0.0f; + for (int i = 0; i < num_partitions; ++i) { + const int idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + i; + float rescale = __expf(max_logits[idx] - global_max_logit); + global_exp_sum += exp_sums[idx] * rescale; + } + + // Share global exp_sum + __shared__ float shared_global_exp_sum; + if (thread_idx == 0) { + shared_global_exp_sum = global_exp_sum; + } + __syncthreads(); + global_exp_sum = shared_global_exp_sum; + + const float inv_global_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Combine partial outputs with proper rescaling + for (int elem_idx = thread_idx; elem_idx < HEAD_SIZE; elem_idx += NUM_THREADS) { + float acc = 0.0f; + + for (int i = 0; i < num_partitions; ++i) { + const int idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + i; + + const scalar_t* tmp_out_ptr = tmp_out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + i * HEAD_SIZE; + + float rescale = __expf(max_logits[idx] - global_max_logit); + float partial_val = float(tmp_out_ptr[elem_idx]); + acc += partial_val * rescale; + } + + // Normalize and write final output + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + from_float(out_ptr[elem_idx], acc * inv_global_sum); + } +} + +// +// Launcher function for V2 +// + +void paged_attention_v2_launcher( + void* out, + const void* query, + const void* key_cache, + const void* value_cache, + int num_seqs, + int num_heads, + int num_kv_heads, + int head_size, + int block_size, + int max_num_blocks_per_seq, + const int* block_tables, + const int* seq_lens, + int max_seq_len, + float scale, + const float* alibi_slopes, + ggml_type q_type, + ggml_type kv_cache_type, + ggml_cuda_pool & pool, + cudaStream_t stream) { + + constexpr int NUM_THREADS = 128; + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + + // Allocate temporary buffers using GGML's pool allocator + // These will be automatically freed when they go out of scope at function exit + const size_t exp_sum_count = num_seqs * num_heads * max_num_partitions; + const size_t max_logit_count = num_seqs * num_heads * max_num_partitions; + const size_t tmp_out_count = num_seqs * num_heads * max_num_partitions * head_size; + + ggml_cuda_pool_alloc exp_sums_alloc(pool, 
exp_sum_count); + ggml_cuda_pool_alloc max_logits_alloc(pool, max_logit_count); + + // Use uint8_t for tmp_out to handle both half and float types generically + const size_t tmp_out_bytes = tmp_out_count * (q_type == GGML_TYPE_F16 ? sizeof(half) : sizeof(float)); + ggml_cuda_pool_alloc tmp_out_alloc(pool, tmp_out_bytes); + + float* exp_sums = exp_sums_alloc.get(); + float* max_logits = max_logits_alloc.get(); + void* tmp_out = tmp_out_alloc.get(); + + // Launch main V2 kernel + { + dim3 grid(num_heads, num_seqs, max_num_partitions); + dim3 block(NUM_THREADS); + + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_THREADS / WARP_SIZE / 2) * head_size * sizeof(float); + const int shared_mem_size = max(logits_size, outputs_size); + + const int q_stride = num_heads * head_size; + const int kv_block_stride = num_kv_heads * head_size * block_size; + const int kv_head_stride = head_size * block_size; + + // Macro to dispatch V2 main kernel based on head size and block size + #define LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_kernel \ + <<>>( \ + exp_sums, max_logits, (SCALAR_T*)tmp_out, \ + (const SCALAR_T*)query, (const CACHE_T*)key_cache, (const CACHE_T*)value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, \ + max_num_blocks_per_seq, alibi_slopes, \ + q_stride, kv_block_stride, kv_head_stride) + + // Dispatch for head size + #define DISPATCH_V2_HEAD_SIZE(SCALAR_T, CACHE_T, BLOCK_SIZE) \ + switch (head_size) { \ + case 32: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 32, BLOCK_SIZE); break; \ + case 64: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 64, BLOCK_SIZE); break; \ + case 80: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 80, BLOCK_SIZE); break; \ + case 96: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 96, BLOCK_SIZE); break; \ + case 112: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 112, BLOCK_SIZE); break; \ + case 120: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 120, BLOCK_SIZE); break; \ + case 128: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 128, BLOCK_SIZE); break; \ + case 192: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 192, BLOCK_SIZE); break; \ + case 256: LAUNCH_PAGED_ATTENTION_V2(SCALAR_T, CACHE_T, 256, BLOCK_SIZE); break; \ + default: \ + fprintf(stderr, "Unsupported head size: %d\n", head_size); \ + GGML_ABORT("fatal error"); \ + } + + // Dispatch for block size + #define DISPATCH_V2_BLOCK_SIZE(SCALAR_T, CACHE_T) \ + switch (block_size) { \ + case 8: DISPATCH_V2_HEAD_SIZE(SCALAR_T, CACHE_T, 8); break; \ + case 16: DISPATCH_V2_HEAD_SIZE(SCALAR_T, CACHE_T, 16); break; \ + case 32: DISPATCH_V2_HEAD_SIZE(SCALAR_T, CACHE_T, 32); break; \ + default: \ + fprintf(stderr, "Unsupported block size: %d\n", block_size); \ + GGML_ABORT("fatal error"); \ + } + + // Type dispatch based on q_type and kv_cache_type + if (q_type == GGML_TYPE_F16 && kv_cache_type == GGML_TYPE_F16) { + DISPATCH_V2_BLOCK_SIZE(half, half); + } else if (q_type == GGML_TYPE_F32 && kv_cache_type == GGML_TYPE_F32) { + DISPATCH_V2_BLOCK_SIZE(float, float); + } else if (q_type == GGML_TYPE_F16 && kv_cache_type == GGML_TYPE_F32) { + DISPATCH_V2_BLOCK_SIZE(half, float); + } else if (q_type == GGML_TYPE_F32 && kv_cache_type == GGML_TYPE_F16) { + DISPATCH_V2_BLOCK_SIZE(float, half); + } else { + fprintf(stderr, "Unsupported data type combination: q_type=%d, kv_cache_type=%d\n", + q_type, kv_cache_type); + GGML_ABORT("fatal error"); + } + + #undef LAUNCH_PAGED_ATTENTION_V2 + #undef DISPATCH_V2_HEAD_SIZE + #undef 
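Note: each partition of the V2 main kernel leaves behind its own max logit, its exp-sum relative to that max, and an unnormalized partial output. The reduce kernel dispatched below merges them with the usual log-sum-exp rescaling; a host-side sketch of that combination for a single output element (illustrative only, not part of the patch):

#include <algorithm>
#include <cmath>

// Combine per-partition softmax partials: rescale every partition to the
// global max logit, then normalize by the rescaled total exp sum.
static float combine_partitions(
        const float * max_logits,   // [num_partitions]
        const float * exp_sums,     // [num_partitions]
        const float * partial_out,  // [num_partitions] unnormalized value of one output element
        int num_partitions) {
    float global_max = -INFINITY;
    for (int p = 0; p < num_partitions; ++p) {
        global_max = std::max(global_max, max_logits[p]);
    }
    float global_sum = 0.0f;
    float acc        = 0.0f;
    for (int p = 0; p < num_partitions; ++p) {
        const float rescale = std::exp(max_logits[p] - global_max);
        global_sum += exp_sums[p]    * rescale;
        acc        += partial_out[p] * rescale;
    }
    return acc / global_sum; // the kernel adds a small epsilon before dividing
}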
DISPATCH_V2_BLOCK_SIZE + + CUDA_CHECK(cudaGetLastError()); + } + + // Launch reduce kernel + { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(NUM_THREADS); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + // Macro to dispatch V2 reduce kernel based on head size + #define LAUNCH_PAGED_ATTENTION_V2_REDUCE(SCALAR_T, HEAD_SIZE) \ + paged_attention_v2_reduce_kernel \ + <<>>( \ + (SCALAR_T*)out, exp_sums, max_logits, (const SCALAR_T*)tmp_out, \ + seq_lens, max_num_partitions) + + // Dispatch reduce kernel for different head sizes and data types + if (q_type == GGML_TYPE_F16) { + switch (head_size) { + case 32: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 32); break; + case 64: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 64); break; + case 80: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 80); break; + case 96: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 96); break; + case 112: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 112); break; + case 120: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 120); break; + case 128: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 128); break; + case 192: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 192); break; + case 256: LAUNCH_PAGED_ATTENTION_V2_REDUCE(half, 256); break; + default: + fprintf(stderr, "Unsupported head size: %d\n", head_size); + GGML_ABORT("fatal error"); + } + } else if (q_type == GGML_TYPE_F32) { + switch (head_size) { + case 32: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 32); break; + case 64: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 64); break; + case 80: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 80); break; + case 96: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 96); break; + case 112: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 112); break; + case 120: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 120); break; + case 128: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 128); break; + case 192: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 192); break; + case 256: LAUNCH_PAGED_ATTENTION_V2_REDUCE(float, 256); break; + default: + fprintf(stderr, "Unsupported head size: %d\n", head_size); + GGML_ABORT("fatal error"); + } + } else { + fprintf(stderr, "Unsupported query data type: %d\n", q_type); + GGML_ABORT("fatal error"); + } + + #undef LAUNCH_PAGED_ATTENTION_V2_REDUCE + + CUDA_CHECK(cudaGetLastError()); + } + + // Temporary buffers are automatically freed when pool_alloc objects go out of scope +} + +} // namespace ggml_cuda_paged_attention + +#endif // GGML_USE_MUSA diff --git a/ggml/src/ggml-cuda/paged-attention.cuh b/ggml/src/ggml-cuda/paged-attention.cuh new file mode 100644 index 00000000000..6590ad43e73 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-attention.cuh @@ -0,0 +1,256 @@ +#pragma once + +/** + * PagedAttention CUDA Kernel Header + * + * Based on the PagedAttention implementation from vLLM: + * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cuh + * + * Copyright (c) 2023-2024 vLLM Project + * SPDX-License-Identifier: Apache-2.0 + * + * Adapted for GGML by llama.cpp contributors + */ + +#include "common.cuh" + +// WARP_SIZE is already defined in common.cuh +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace ggml_cuda_paged_attention { + +// Partition size for PagedAttention V2 +constexpr int PARTITION_SIZE = 512; + +// Supported head sizes +constexpr int SUPPORTED_HEAD_SIZES[] = {32, 64, 80, 96, 112, 120, 128, 192, 256}; + +// +// Helper structures and functions +// + +// Vector types for efficient memory access +template +struct Vec { + using Type = T; +}; + +template<> struct Vec { using Type = half; }; +template<> struct Vec { using Type = half2; }; +template<> struct Vec { using Type = uint2; }; // 4 halfs = 64 bits +template<> struct Vec { using Type = uint4; }; // 8 halfs = 128 bits + +template<> struct Vec { using Type = float; }; +template<> struct Vec { using Type = float2; }; +template<> struct Vec { using Type = float4; }; + +// Float vector type conversion +template +struct FloatVec { + using Type = L_vec; +}; + +// Warp shuffle utilities +#define SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask, WARP_SIZE) +#define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane, WARP_SIZE) + +// Block-level reduction +template +__inline__ __device__ float block_sum(float* red_smem, float sum) { + // Decompose thread index into warp / lane + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Warp-level reduction + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store to shared memory + if (lane == 0) { + red_smem[warp] = sum; + } + __syncthreads(); + + // Final reduction across warps + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + #pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast result + return SHFL_SYNC(sum, 0); +} + +// Dot product helpers +template +__inline__ __device__ float dot(T a, T b) { + // Default implementation for scalar types + return float(a) * float(b); +} + +__inline__ __device__ float dot(half2 a, half2 b) { + float2 a_f = __half22float2(a); + float2 b_f = __half22float2(b); + return a_f.x * b_f.x + a_f.y * b_f.y; +} + +// Convert from float +template +__inline__ __device__ void from_float(T& dst, float src) { + dst = T(src); +} + +__inline__ __device__ void from_float(half& dst, float src) { + dst = __float2half(src); +} + +__inline__ __device__ void from_float(half2& dst, float src) { + dst = __float2half2_rn(src); +} + +// Zero initialization +template +__inline__ __device__ void zero(T& val) { + val = T(0); +} + +// +// PagedAttention V1 Kernel +// +// For shorter sequences (≤8192 tokens) +// Each thread block processes one head of one sequence +// + +template // Threads per block (e.g., 128) +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] or nullptr + const int q_stride, // stride for q + const int kv_block_stride, // stride between blocks in cache + const int kv_head_stride); // 
stride between heads in cache + +// +// PagedAttention V2 Kernel +// +// For longer sequences (>8192 tokens) +// Uses partitioning to avoid shared memory limits +// + +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, + const cache_t* __restrict__ k_cache, + const cache_t* __restrict__ v_cache, + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, + const int* __restrict__ seq_lens, + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride); + +// +// PagedAttention V2 Reduce Kernel +// +// Combines partial results from V2 main kernel +// + +template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions); + +// +// Launcher functions (to be called from GGML backend) +// + +// Launch PagedAttention V1 +void paged_attention_v1_launcher( + void* out, // Output tensor + const void* query, // Query tensor + const void* key_cache, // Key cache (paged) + const void* value_cache, // Value cache (paged) + int num_seqs, + int num_heads, + int num_kv_heads, + int head_size, + int block_size, + int max_num_blocks_per_seq, + const int* block_tables, + const int* seq_lens, + int max_seq_len, + float scale, + const float* alibi_slopes, // Can be nullptr + ggml_type q_type, // Query data type + ggml_type kv_cache_type, // KV cache data type + cudaStream_t stream); + +// Launch PagedAttention V2 +void paged_attention_v2_launcher( + void* out, + const void* query, + const void* key_cache, + const void* value_cache, + int num_seqs, + int num_heads, + int num_kv_heads, + int head_size, + int block_size, + int max_num_blocks_per_seq, + const int* block_tables, + const int* seq_lens, + int max_seq_len, + float scale, + const float* alibi_slopes, + ggml_type q_type, + ggml_type kv_cache_type, + ggml_cuda_pool & pool, + cudaStream_t stream); + +// Helper: Decide which version to use +inline bool should_use_v1(int max_seq_len, int num_seqs, int num_heads) { + const int max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + // Use V1 if: + // - Sequence is short enough (≤8192) AND + // - Either we have only 1 partition OR we have lots of parallelism + return max_seq_len <= 8192 && (max_num_partitions == 1 || num_seqs * num_heads > 512); +} + +} // namespace ggml_cuda_paged_attention diff --git a/ggml/src/ggml-cuda/paged-cpy.cu b/ggml/src/ggml-cuda/paged-cpy.cu new file mode 100644 index 00000000000..e2edb97a4e6 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-cpy.cu @@ -0,0 +1,151 @@ +/** + * GGML CUDA Backend for PagedCopy (GGML_OP_PAGED_CPY) + * + * This file implements the CUDA kernel for copying K/V data to paged cache blocks. + * Similar to vLLM's reshape_and_cache kernel. 
+ */ + +#ifndef GGML_USE_MUSA + +#include "common.cuh" + +// CUDA kernel for copying K/V data to paged blocks +// Inspired by vLLM's reshape_and_cache kernel +template +__global__ void paged_cpy_kernel( + const T* __restrict__ kv_cur, // [head_size, n_heads, n_tokens] + T* __restrict__ kv_cache, // [num_blocks, n_kv_heads, head_size, block_size] + const int32_t* __restrict__ slot_idxs, // [n_tokens] - slot index for each token + int head_size, + int n_heads, + int n_kv_heads, + int n_tokens, + int block_size) { + + // Each block processes one token + const int token_idx = blockIdx.x; + if (token_idx >= n_tokens) return; + + // Get the cache slot for this token + const int slot_idx = slot_idxs[token_idx]; + const int block_id = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + // GQA: map query head to kv head + const int head_ratio = n_heads / n_kv_heads; + const int head_idx = threadIdx.y; // which head + const int kv_head_idx = head_idx / head_ratio; + + // Each thread copies one element of head_size + const int elem_idx = threadIdx.x; + if (elem_idx >= head_size) return; + + // Source: kv_cur[elem_idx, head_idx, token_idx] + // Layout: [head_size, n_heads, n_tokens] + const int src_idx = elem_idx + head_idx * head_size + token_idx * head_size * n_heads; + const T value = kv_cur[src_idx]; + + // Destination: kv_cache[block_id, kv_head_idx, elem_idx, block_offset] + // Layout: [num_blocks, n_kv_heads, head_size, block_size] + const int dst_idx = block_id * (n_kv_heads * head_size * block_size) + + kv_head_idx * (head_size * block_size) + + elem_idx * block_size + + block_offset; + + kv_cache[dst_idx] = value; +} + +// Launcher function +void ggml_cuda_op_paged_cpy( + ggml_backend_cuda_context & ctx, + ggml_tensor * dst) { + + const ggml_tensor * kv_cur = dst->src[0]; // [head_size, n_heads, n_tokens] + const ggml_tensor * kv_cache = dst->src[1]; // [num_blocks, n_kv_heads, head_size, block_size] + const ggml_tensor * slot_idxs = dst->src[2]; // [n_tokens] (int32) + + // Get dimensions + const int head_size = kv_cur->ne[0]; + const int n_heads = kv_cur->ne[1]; + const int n_tokens = kv_cur->ne[2]; + + const int num_blocks = kv_cache->ne[0]; + const int n_kv_heads = kv_cache->ne[1]; + const int block_size = kv_cache->ne[3]; + + GGML_ASSERT(head_size == kv_cache->ne[2]); + GGML_ASSERT(n_tokens == slot_idxs->ne[0]); + GGML_ASSERT(slot_idxs->type == GGML_TYPE_I32); + + // Skip if there are no tokens to copy + if (n_tokens == 0) { + return; + } + + // Get pointers + const void * kv_cur_ptr = kv_cur->data; + void * kv_cache_ptr = kv_cache->data; + const int32_t * slot_idxs_ptr = (const int32_t *)slot_idxs->data; + + // Get CUDA stream + cudaStream_t stream = ctx.stream(); + + // Launch kernel + // Grid: one block per token + // Block: head_size threads in x, n_heads threads in y + dim3 grid(n_tokens); + dim3 block(head_size, n_heads); + + // Ensure block dimensions are valid + GGML_ASSERT(head_size * n_heads <= 1024); // max threads per block + + // Debug logging + fprintf(stderr, "paged_cpy: head_size=%d, n_heads=%d, n_kv_heads=%d, n_tokens=%d, block_size=%d\n", + head_size, n_heads, n_kv_heads, n_tokens, block_size); + fprintf(stderr, "paged_cpy: kv_cur dims=[%lld,%lld,%lld,%lld], kv_cache dims=[%lld,%lld,%lld,%lld]\n", + kv_cur->ne[0], kv_cur->ne[1], kv_cur->ne[2], kv_cur->ne[3], + kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3]); + fprintf(stderr, "paged_cpy: pointers: kv_cur=%p, kv_cache=%p, slot_idxs=%p\n", + kv_cur_ptr, kv_cache_ptr, 
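Note: a host-side reference of the scatter performed by paged_cpy_kernel may make the [num_blocks, n_kv_heads, head_size, block_size] layout easier to follow from the write side. Illustrative only, not part of the patch; it mirrors the kernel's index math, including the repeated writes that occur when several query heads share one KV head:

#include <cstdint>
#include <cstddef>

// Scatter each token's K (or V) vector into its paged cache slot.
static void paged_cpy_ref(
        const float   * kv_cur,    // [head_size, n_heads, n_tokens] (ggml ne order)
        float         * kv_cache,  // [num_blocks][n_kv_heads][head_size][block_size]
        const int32_t * slot_idxs, // [n_tokens]
        int head_size, int n_heads, int n_kv_heads, int n_tokens, int block_size) {
    const int head_ratio = n_heads / n_kv_heads; // GQA: several query heads share one KV head
    for (int t = 0; t < n_tokens; ++t) {
        const int slot   = slot_idxs[t];
        const int blk    = slot / block_size;
        const int offset = slot % block_size;
        for (int h = 0; h < n_heads; ++h) {
            const int kv_h = h / head_ratio;
            for (int e = 0; e < head_size; ++e) {
                const size_t src = (size_t) e + (size_t) h * head_size
                                 + (size_t) t * head_size * n_heads;
                const size_t dst = (size_t) blk  * n_kv_heads * head_size * block_size
                                 + (size_t) kv_h * head_size  * block_size
                                 + (size_t) e    * block_size
                                 + offset;
                kv_cache[dst] = kv_cur[src]; // same slot is written once per query head, as in the kernel
            }
        }
    }
}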
slot_idxs_ptr); + + if (kv_cur->type == GGML_TYPE_F16) { + paged_cpy_kernel<<>>( + (const half *)kv_cur_ptr, + (half *)kv_cache_ptr, + slot_idxs_ptr, + head_size, + n_heads, + n_kv_heads, + n_tokens, + block_size); + } else if (kv_cur->type == GGML_TYPE_F32) { + paged_cpy_kernel<<>>( + (const float *)kv_cur_ptr, + (float *)kv_cache_ptr, + slot_idxs_ptr, + head_size, + n_heads, + n_kv_heads, + n_tokens, + block_size); + } else { + GGML_ABORT("Unsupported type for paged_cpy"); + } + + CUDA_CHECK(cudaGetLastError()); +} + +#else // GGML_USE_MUSA + +// Stub for MUSA +#include "common.cuh" + +void ggml_cuda_op_paged_cpy( + ggml_backend_cuda_context & ctx, + ggml_tensor * dst) { + GGML_UNUSED(ctx); + GGML_UNUSED(dst); + GGML_ABORT("PagedCopy is not yet supported on MUSA"); +} + +#endif // GGML_USE_MUSA diff --git a/ggml/src/ggml-cuda/paged-cpy.cuh b/ggml/src/ggml-cuda/paged-cpy.cuh new file mode 100644 index 00000000000..901b04c3484 --- /dev/null +++ b/ggml/src/ggml-cuda/paged-cpy.cuh @@ -0,0 +1,6 @@ +#pragma once + +#include "common.cuh" + +// CUDA backend function for GGML_OP_PAGED_CPY +void ggml_cuda_op_paged_cpy(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b99345a2e93..326d5e77e55 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -997,6 +997,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", + "PAGED_ATTENTION", + "PAGED_CPY", "SSM_CONV", "SSM_SCAN", "WIN_PART", @@ -1024,7 +1026,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1106,6 +1108,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_attn_ext(x)", "flash_attn_back(x)", + "paged_attn(q,k,v,bt,sl)", + "paged_cpy(kv_cur,kv_cache,slot_idxs)", "ssm_conv(x)", "ssm_scan(x)", "win_part(x)", @@ -1133,7 +1137,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5268,6 +5272,68 @@ void ggml_flash_attn_ext_add_sinks( a->src[4] = sinks; } +// ggml_paged_attention + +struct ggml_tensor * ggml_paged_attention( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k_cache, + struct ggml_tensor * v_cache, + struct ggml_tensor * block_tables, + struct ggml_tensor * seq_lens, + int32_t block_size, + float scale) { + + // Validate inputs + GGML_ASSERT(q->ne[0] == k_cache->ne[2]); // head_size must match + GGML_ASSERT(k_cache->ne[2] == v_cache->ne[2]); // k and v head_size must match + GGML_ASSERT(block_tables->type == GGML_TYPE_I32); + GGML_ASSERT(seq_lens->type == GGML_TYPE_I32); + + // Output shape: [head_size, n_heads, n_tokens] + // Same as input query shape + int64_t ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, q->type, 4, ne); + + // Store parameters: scale and block_size + float params[] = { scale, (float)block_size }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_PAGED_ATTENTION; + result->src[0] = q; + result->src[1] = k_cache; + result->src[2] = v_cache; + result->src[3] = block_tables; + result->src[4] = seq_lens; + + return result; +} + +// ggml_paged_cpy + +struct ggml_tensor * ggml_paged_cpy( + 
struct ggml_context * ctx, + struct ggml_tensor * kv_cur, + struct ggml_tensor * kv_cache, + struct ggml_tensor * slot_idxs) { + + // Validate inputs + GGML_ASSERT(kv_cur->ne[0] == kv_cache->ne[2]); // head_size must match + GGML_ASSERT(slot_idxs->type == GGML_TYPE_I32); + GGML_ASSERT(slot_idxs->ne[0] == kv_cur->ne[2]); // one slot idx per token + + // Output shape: same as kv_cache (operation modifies it in-place) + // But we return kv_cache itself to add this op to the graph + struct ggml_tensor * result = ggml_view_tensor(ctx, kv_cache); + + result->op = GGML_OP_PAGED_CPY; + result->src[0] = kv_cur; + result->src[1] = kv_cache; + result->src[2] = slot_idxs; + + return result; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( diff --git a/include/llama.h b/include/llama.h index b52eaacfa7e..7ec2fdef845 100644 --- a/include/llama.h +++ b/include/llama.h @@ -363,6 +363,7 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + bool use_paged_attention; // use PagedAttention for KV cache (experimental, requires CUDA) }; // model quantization parameters diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 67c7807e092..206e8784316 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(llama llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-paged.cpp llama-memory.cpp llama-memory-hybrid.cpp llama-memory-recurrent.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e04f0fc4f9a..c46bfa37b4e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -89,6 +89,18 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + fprintf(stderr, "DEBUG llama_context line 91: before assignment, cparams.use_paged_attention = %s\n", + cparams.use_paged_attention ? "true" : "false"); + fprintf(stderr, "DEBUG llama_context line 91: params.use_paged_attention = %s\n", + params.use_paged_attention ? "true" : "false"); + cparams.use_paged_attention = params.use_paged_attention; + fprintf(stderr, "DEBUG llama_context line 92: after assignment, cparams.use_paged_attention = %s\n", + cparams.use_paged_attention ? "true" : "false"); + + if (params.use_paged_attention) { + LLAMA_LOG_INFO("%s: params.use_paged_attention = true, cparams.use_paged_attention = %s\n", + __func__, cparams.use_paged_attention ? "true" : "false"); + } // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -309,7 +321,13 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + fprintf(stderr, "DEBUG llama_context line 318: checking cparams.use_paged_attention = %s\n", + cparams.use_paged_attention ? 
"true" : "false"); + if (cparams.use_paged_attention) { + // PagedAttention is enabled, disable Flash Attention + cparams.flash_attn = false; + LLAMA_LOG_INFO("%s: PagedAttention is enabled, Flash Attention disabled\n", __func__); + } else if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); if (!gf) { throw std::runtime_error("failed to split graph for Flash Attention check"); @@ -344,6 +362,10 @@ llama_context::llama_context( if (ggml_is_quantized(params.type_v)) { throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } + } else if (cparams.use_paged_attention) { + // PagedAttention and Flash Attention are incompatible + cparams.flash_attn = false; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to disabled (PagedAttention is enabled)\n", __func__); } else { cparams.flash_attn = true; LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); @@ -2323,6 +2345,7 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.use_paged_attention =*/ false, }; return result; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fcef8fa9760..c77c53b1271 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -30,6 +30,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool use_paged_attention; bool no_perf; bool warmup; bool op_offload; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d012e09aba..51b38c4795e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -6,6 +6,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-paged.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -421,6 +422,40 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_paged::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + mctx->set_input_v_idxs(self_v_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + // Populate block tables and sequence lengths for PagedAttention + const auto * paged_cache = mctx->get_kv_paged(); + if (paged_cache) { + if (block_tables) { + paged_cache->populate_block_tables_tensor(block_tables); + } + if (seq_lens) { + paged_cache->populate_seq_lens_tensor(seq_lens); + } + } +} + +bool llm_graph_input_attn_paged::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + return res; +} + void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(cross_kq_mask); @@ -1347,18 +1382,63 @@ ggml_tensor * llm_graph_context::build_attn_mha( int il) const { const bool v_trans = v->nb[1] > v->nb[2]; - // split the batch into streams if needed - const auto n_stream = k->ne[3]; + ggml_tensor * cur; - q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); + // PagedAttention path (highest priority) + // PagedAttention uses unpermuted tensors: [head_size, 
n_heads, n_tokens, n_seqs] + if (cparams.use_paged_attention) { + // Get paged cache from context + const auto * paged_ctx = dynamic_cast(mctx); + GGML_ASSERT(paged_ctx != nullptr && "use_paged_attention is true but context is not paged"); + const auto * paged_cache = paged_ctx->get_kv_paged(); + GGML_ASSERT(paged_cache != nullptr && "paged context has no cache"); + + // Get K and V cache blocks for this layer + ggml_tensor * k_blocks = paged_cache->get_k_blocks(il); + ggml_tensor * v_blocks = paged_cache->get_v_blocks(il); + GGML_ASSERT(k_blocks != nullptr && v_blocks != nullptr); + + // Get block tables and seq_lens from graph input (built during graph construction) + // These will be populated with actual data in set_input() before execution + ggml_tensor * block_tables_tensor = nullptr; + ggml_tensor * seq_lens_tensor = nullptr; + + // Try to get from input if available (when called from build_attn with paged input) + // Otherwise build them here (fallback for direct calls) + if (false) { // TODO: need to pass inp_paged to build_attn_mha + // Would get from inp_paged->block_tables and inp_paged->seq_lens + } else { + block_tables_tensor = paged_cache->build_block_tables_tensor(ctx0); + seq_lens_tensor = paged_cache->build_seq_lens_tensor(ctx0); + } + GGML_ASSERT(block_tables_tensor != nullptr && seq_lens_tensor != nullptr); + + // Call paged attention operation + // Note: q is already permuted to [head_size, n_tokens, n_heads, n_seqs] + cur = ggml_paged_attention( + ctx0, + q, + k_blocks, + v_blocks, + block_tables_tensor, + seq_lens_tensor, + paged_cache->get_block_size(), + kq_scale + ); + cb(cur, LLAMA_TENSOR_NAME_PAGED_ATTN, il); - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - k = ggml_permute(ctx0, k, 0, 2, 1, 3); - v = ggml_permute(ctx0, v, 0, 2, 1, 3); + // Reshape to match expected output format: [head_size * n_heads, n_tokens * n_seqs] + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + } else { + // For non-paged paths, split batch into streams and apply permutation + const auto n_stream = k->ne[3]; + q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); - ggml_tensor * cur; + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); - if (cparams.flash_attn && kq_b == nullptr) { + if (cparams.flash_attn && kq_b == nullptr) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1462,6 +1542,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( // all nodes between the KV store and the attention output are run on the CPU ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu); } + } } ggml_build_forward_expand(gf, cur); @@ -1578,6 +1659,18 @@ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const { return (llm_graph_input_attn_kv *) res->add_input(std::move(inp)); } +llm_graph_input_i * llm_graph_context::build_attn_inp() const { + // Check if we're using paged attention + const auto * mctx_paged = dynamic_cast(mctx); + if (mctx_paged) { + // Return paged input + return build_attn_inp_paged(); + } + + // Fall back to regular KV cache + return build_attn_inp_kv(); +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv * inp, ggml_tensor * wo, @@ -1632,6 +1725,82 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_paged * inp, + ggml_tensor * wo, + ggml_tensor * 
wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + GGML_UNUSED(kq_b); + GGML_UNUSED(sinks); + GGML_UNUSED(v_mla); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + const auto * paged_cache = mctx_cur->get_kv_paged(); + GGML_ASSERT(paged_cache != nullptr); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + const auto & v_idxs = inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); + } + + // Get K and V cache blocks for this layer + ggml_tensor * k_blocks = paged_cache->get_k_blocks(il); + ggml_tensor * v_blocks = paged_cache->get_v_blocks(il); + GGML_ASSERT(k_blocks != nullptr && v_blocks != nullptr); + + // Use block_tables and seq_lens from input (populated in set_input()) + GGML_ASSERT(inp->block_tables != nullptr && inp->seq_lens != nullptr); + + // Call PagedAttention operation + // q_cur is unpermuted: [head_size, n_heads, n_tokens, n_seqs] + ggml_tensor * cur = ggml_paged_attention( + ctx0, + q_cur, + k_blocks, + v_blocks, + inp->block_tables, + inp->seq_lens, + paged_cache->get_block_size(), + kq_scale + ); + cb(cur, "kqv_paged_attn", il); + + // Reshape to match expected output format: [head_size * n_heads, n_tokens * n_seqs] + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -1793,6 +1962,33 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } +llm_graph_input_attn_paged * llm_graph_context::build_attn_inp_paged() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + // Build block tables and seq lens tensors for PagedAttention + const auto * paged_cache = mctx_cur->get_kv_paged(); + if (paged_cache) { + inp->block_tables = paged_cache->build_block_tables_tensor(ctx0); + inp->seq_lens = paged_cache->build_seq_lens_tensor(ctx0); + } + + return (llm_graph_input_attn_paged *) res->add_input(std::move(inp)); +} + ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, ggml_tensor * state_copy_main, diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f679..30612f4be46 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -21,6 +21,7 @@ struct llama_memory_context_i; class llama_kv_cache_context; class llama_kv_cache_iswa_context; +class llama_kv_cache_paged_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -346,6 +347,42 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { const llama_kv_cache_iswa_context * mctx; }; +class llm_graph_input_attn_paged : public llm_graph_input_i { +public: + llm_graph_input_attn_paged( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_paged_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_paged() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * block_tables = nullptr; // I32 [max_blocks_per_seq, n_seqs] + ggml_tensor * seq_lens = nullptr; // I32 [n_seqs] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_paged_context * mctx; +}; + class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -719,6 +756,7 @@ struct llm_graph_context { int il) const; llm_graph_input_attn_kv * build_attn_inp_kv() const; + llm_graph_input_i * build_attn_inp() const; ggml_tensor * build_attn( llm_graph_input_attn_kv * inp, @@ -735,6 +773,21 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; + llm_graph_input_attn_paged * build_attn_inp_paged() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_paged * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( llm_graph_input_attn_kv_iswa * inp, diff --git a/src/llama-impl.h b/src/llama-impl.h index c5163e9225a..c5aa6e01b59 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -61,3 +61,4 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); 
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); #define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_PAGED_ATTN "__paged_attn__" diff --git a/src/llama-kv-cache-paged.cpp b/src/llama-kv-cache-paged.cpp new file mode 100644 index 00000000000..2780e3a2dfc --- /dev/null +++ b/src/llama-kv-cache-paged.cpp @@ -0,0 +1,1061 @@ +#include "llama-kv-cache-paged.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-hparams.h" +#include "llama-model.h" +#include "llama-kv-cache.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_paged_context implementation +// + +llama_kv_cache_paged_context::llama_kv_cache_paged_context(llama_memory_status status) + : status(status), kv_paged(nullptr), ubatch() { + fprintf(stderr, "llama_kv_cache_paged_context::llama_kv_cache_paged_context(status=%d) called\n", status); + // ubatch is value-initialized to zero +} + +llama_kv_cache_paged_context::llama_kv_cache_paged_context(llama_kv_cache_paged * kv_paged) + : status(LLAMA_MEMORY_STATUS_SUCCESS), kv_paged(kv_paged), ubatch() { + fprintf(stderr, "llama_kv_cache_paged_context::llama_kv_cache_paged_context(kv_paged=%p) called\n", (void*)kv_paged); + // ubatch is value-initialized to zero +} + +// Stub implementations for llama_kv_cache_context-like interface +// These are called by graph building code via static_cast +// TODO: Implement proper PagedAttention logic for these methods + +ggml_tensor * llama_kv_cache_paged_context::get_k(ggml_context * ctx, int32_t il) const { + GGML_UNUSED(ctx); + if (!kv_paged) { + fprintf(stderr, "ERROR: llama_kv_cache_paged_context::get_k() called with null kv_paged\n"); + return nullptr; + } + // Return the full paged K cache tensor for this layer + // The PagedAttention kernel will handle block indexing + return kv_paged->get_k_blocks(il); +} + +ggml_tensor * llama_kv_cache_paged_context::get_v(ggml_context * ctx, int32_t il) const { + GGML_UNUSED(ctx); + if (!kv_paged) { + fprintf(stderr, "ERROR: llama_kv_cache_paged_context::get_v() called with null kv_paged\n"); + return nullptr; + } + // Return the full paged V cache tensor for this layer + // The PagedAttention kernel will handle block indexing + return kv_paged->get_v_blocks(il); +} + +ggml_tensor * llama_kv_cache_paged_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { + if (!kv_paged) { + fprintf(stderr, "ERROR: llama_kv_cache_paged_context::cpy_k() called with null kv_paged\n"); + return nullptr; + } + + // Get K cache blocks for this layer + auto * k_cache = kv_paged->get_k_blocks(il); + if (!k_cache) { + return nullptr; + } + + // Use ggml_paged_cpy to copy K data to paged cache blocks + // k_cur shape: [head_size, n_heads, n_tokens] + // k_cache shape: [num_blocks, n_kv_heads, head_size, block_size] + // k_idxs shape: [n_tokens] - slot index for each token + return ggml_paged_cpy(ctx, k_cur, k_cache, k_idxs); +} + +ggml_tensor * llama_kv_cache_paged_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { + if (!kv_paged) { + fprintf(stderr, "ERROR: llama_kv_cache_paged_context::cpy_v() called with null kv_paged\n"); + return nullptr; + } + + // Get V cache blocks for this layer + auto * v_cache = kv_paged->get_v_blocks(il); + if (!v_cache) { + return nullptr; + } + + // Use ggml_paged_cpy to copy V data to paged cache blocks + // v_cur shape: [head_size, n_heads, n_tokens] + // v_cache shape: [num_blocks, n_kv_heads, 
head_size, block_size] + // v_idxs shape: [n_tokens] - slot index for each token + return ggml_paged_cpy(ctx, v_cur, v_cache, v_idxs); +} + +ggml_tensor * llama_kv_cache_paged_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + // TODO: Proper paged block index calculation + // For now, create a simple sequential index tensor + // This won't work correctly for PagedAttention but allows graph building to proceed + const int64_t n_tokens = ubatch.n_tokens; + auto * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_name(result, "k_idxs_paged"); + return result; +} + +ggml_tensor * llama_kv_cache_paged_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + // TODO: Proper paged block index calculation + // For now, create a simple sequential index tensor + // This won't work correctly for PagedAttention but allows graph building to proceed + const int64_t n_tokens = ubatch.n_tokens; + auto * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_name(result, "v_idxs_paged"); + return result; +} + +void llama_kv_cache_paged_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + if (!dst || !ubatch) return; + + // TODO: Proper paged block indexing + // For now, fill with sequential indices + GGML_ASSERT(dst->type == GGML_TYPE_I32); + int32_t * data = (int32_t *) dst->data; + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + data[i] = i; + } +} + +void llama_kv_cache_paged_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + if (!dst || !ubatch) return; + + // TODO: Proper paged block indexing + // For now, fill with sequential indices + GGML_ASSERT(dst->type == GGML_TYPE_I32); + int32_t * data = (int32_t *) dst->data; + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + data[i] = i; + } +} + +void llama_kv_cache_paged_context::set_input_k_shift(ggml_tensor * dst) const { + // K shifting not supported with PagedAttention + GGML_UNUSED(dst); +} + +void llama_kv_cache_paged_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + if (!dst || !ubatch) return; + + // TODO: Proper PagedAttention mask handling + // For now, create a simple causal mask + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_kv = dst->ne[0]; // KV sequence length + + float * data = (float *) dst->data; + + if (causal_attn) { + // Causal mask: can attend to current and previous tokens + for (int64_t i = 0; i < n_tokens; ++i) { + for (int64_t j = 0; j < n_kv; ++j) { + data[i * n_kv + j] = (j <= i) ? 
0.0f : -INFINITY; + } + } + } else { + // No mask: can attend to all tokens + for (int64_t i = 0; i < n_tokens * n_kv; ++i) { + data[i] = 0.0f; + } + } +} + +void llama_kv_cache_paged_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + // Position bucketing not used with basic PagedAttention + GGML_UNUSED(dst); + GGML_UNUSED(ubatch); +} + +// +// llama_kv_cache_paged implementation +// + +llama_kv_cache_paged::llama_kv_cache_paged( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t block_size, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) + : model(model), + hparams(model.hparams), + type_k(type_k), + type_v(type_v), + n_seq_max(n_seq_max), + block_size(block_size), + num_blocks((kv_size + block_size - 1) / block_size) { // ceil division + + GGML_ASSERT(block_size > 0 && block_size <= 256); + GGML_ASSERT((block_size & (block_size - 1)) == 0 && "block_size must be power of 2"); + + // Check environment variable for debug output + const char * debug_env = std::getenv("LLAMA_KV_CACHE_DEBUG"); + if (debug_env) { + debug = std::atoi(debug_env); + } + + if (debug > 0) { + fprintf(stderr, "%s: initializing paged KV cache with %u blocks of size %u (total capacity: %u tokens)\n", + __func__, num_blocks, block_size, num_blocks * block_size); + } + + // Build layer list (same as standard KV cache) + const int32_t n_layer = hparams.n_layer; + + for (int32_t il = 0; il < n_layer; ++il) { + if (filter && !filter(il)) { + continue; + } + + // Check if this layer should reuse memory from another layer + const int32_t il_reuse = reuse ? reuse(il) : -1; + + if (il_reuse >= 0) { + // Reuse memory from another layer + auto it = map_layer_ids.find(il_reuse); + GGML_ASSERT(it != map_layer_ids.end() && "layer to reuse not found"); + map_layer_ids[il] = it->second; + continue; + } + + kv_layer layer; + layer.il = il; + + // Initialize block storage + layer.blocks.resize(num_blocks); + for (uint32_t i = 0; i < num_blocks; ++i) { + layer.blocks[i].id = i; + layer.blocks[i].is_free = true; + layer.blocks[i].ref_count = 0; + } + + // Add to layer list + const int32_t il_kv = static_cast(layers.size()); + layers.push_back(std::move(layer)); + map_layer_ids[il] = il_kv; + } + + // Initialize free block list + for (uint32_t i = 0; i < num_blocks; ++i) { + free_blocks.push_back(i); + } + + if (debug > 0) { + fprintf(stderr, "%s: created %zu layers with %u blocks each\n", + __func__, layers.size(), num_blocks); + fprintf(stderr, "%s: map_layer_ids contains %zu entries:\n", __func__, map_layer_ids.size()); + for (const auto & [il, il_kv] : map_layer_ids) { + fprintf(stderr, "%s: layer %d -> kv_layer %d\n", __func__, il, il_kv); + } + } + + // Allocate tensor memory for blocks + const int32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + // const int32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // unused for now + const int32_t n_head_kv = hparams.n_head_kv(); + + // Create context map for different buffer types + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // Allocate space for: + // - 2 base tensors per layer (k_all_blocks, v_all_blocks) + // - 2 * num_blocks 
view tensors per layer (k_data, v_data for each block) + // Total: layers.size() * 2 * (1 + num_blocks) + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*layers.size()*(1 + num_blocks)*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map.emplace(buft, ctx); + return ctx; + } + return it->second.get(); + }; + + // Create tensors for each layer + for (auto & layer : layers) { + const int32_t il = layer.il; + + // Determine buffer type (CPU or GPU) + bool offload = model.dev_layer(il) != nullptr; + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + } + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for paged kv cache"); + } + + // Create tensors for all blocks in this layer + // Shape: [num_blocks, num_kv_heads, head_size, block_size] + // This matches the expected layout for PagedAttention CUDA kernels + const int64_t head_size = n_embd_k_gqa / n_head_kv; + layer.k_all_blocks = ggml_new_tensor_4d(ctx, type_k, num_blocks, n_head_kv, head_size, block_size); + layer.v_all_blocks = ggml_new_tensor_4d(ctx, type_v, num_blocks, n_head_kv, head_size, block_size); + + ggml_format_name(layer.k_all_blocks, "paged_cache_k_l%d", il); + ggml_format_name(layer.v_all_blocks, "paged_cache_v_l%d", il); + + // Update individual block pointers to reference parts of the contiguous tensor + for (uint32_t i = 0; i < num_blocks; ++i) { + // Create views into the all_blocks tensors + // Each block is a slice along dimension 0: [num_kv_heads, head_size, block_size] + // With layout [num_blocks, num_kv_heads, head_size, block_size], we slice the first dim + const size_t offset = i * layer.k_all_blocks->nb[0]; + layer.blocks[i].k_data = ggml_view_3d(ctx, layer.k_all_blocks, + n_head_kv, head_size, block_size, + layer.k_all_blocks->nb[1], layer.k_all_blocks->nb[2], offset); + layer.blocks[i].v_data = ggml_view_3d(ctx, layer.v_all_blocks, + n_head_kv, head_size, block_size, + layer.v_all_blocks->nb[1], layer.v_all_blocks->nb[2], offset); + } + } + + // Allocate buffers for all contexts + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for paged kv cache"); + } + + if (debug > 0) { + fprintf(stderr, "%s: %10s paged KV buffer size = %8.2f MiB\n", __func__, + ggml_backend_buffer_name(buf), + ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + } + + // Clear buffer to avoid NaN values + ggml_backend_buffer_clear(buf, 0); + + // Store context and buffer pair + ctxs_bufs.emplace_back(std::move(ctx), buf); + } +} + +// +// llama_memory_i interface implementation +// + +llama_memory_context_ptr llama_kv_cache_paged::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(n_ubatch); + GGML_UNUSED(embd_all); + + const auto & batch = balloc.get_batch(); + + if (debug > 0) { + fprintf(stderr, "%s: processing batch with %d tokens\n", + __func__, batch.n_tokens); + fprintf(stderr, "%s: current state: %zu sequences, %zu free blocks\n", + __func__, block_tables.size(), free_blocks.size()); + } + + // Process each token position to ensure blocks are allocated + for (int i = 0; i < batch.n_tokens; ++i) { + // Handle null arrays with defaults: + // - 
n_seq_id defaults to 1 + // - seq_id defaults to 0 + // - pos defaults to sequential (i) + const int n_seqs = batch.n_seq_id ? batch.n_seq_id[i] : 1; + + if (debug > 1) { + const llama_pos pos_debug = batch.pos ? batch.pos[i] : i; + fprintf(stderr, "%s: token %d: n_seqs=%d, pos=%d\n", + __func__, i, n_seqs, pos_debug); + } + + for (int j = 0; j < n_seqs; ++j) { + const llama_seq_id seq_id = batch.seq_id ? batch.seq_id[i][j] : 0; + const llama_pos pos = batch.pos ? batch.pos[i] : i; + + if (debug > 1) { + fprintf(stderr, "%s: seq_id=%d, pos=%d\n", + __func__, seq_id, pos); + } + + // Check if this sequence needs blocks + auto & blocks = block_tables[seq_id]; + auto & meta = seq_meta[seq_id]; + + // Calculate required blocks for this position + const uint32_t required_blocks = (pos + block_size) / block_size; + + if (debug > 1) { + fprintf(stderr, "%s: current blocks: %zu, required: %u\n", + __func__, blocks.size(), required_blocks); + } + + // Allocate more blocks if needed + while (blocks.size() < required_blocks) { + uint32_t block_id = allocate_block(); + if (block_id == UINT32_MAX) { + fprintf(stderr, "%s: ERROR: failed to allocate block for seq %d at pos %d\n", + __func__, seq_id, pos); + fprintf(stderr, "%s: ERROR: free_blocks.size()=%zu, block_tables.size()=%zu\n", + __func__, free_blocks.size(), block_tables.size()); + return llama_memory_context_ptr( + new llama_kv_cache_paged_context(LLAMA_MEMORY_STATUS_FAILED_PREPARE)); + } + blocks.push_back(block_id); + + if (debug > 1) { + fprintf(stderr, "%s: allocated block %u for seq %d (total blocks: %zu, free remaining: %zu)\n", + __func__, block_id, seq_id, blocks.size(), free_blocks.size()); + } + } + + // Update sequence metadata + if (meta.pos_min < 0 || pos < meta.pos_min) { + meta.pos_min = pos; + } + if (pos > meta.pos_max) { + meta.pos_max = pos; + } + meta.length = static_cast(meta.pos_max - meta.pos_min + 1); + } + } + + // Populate out_ids based on batch.logits + // This is required for llama_context to properly track which tokens produce outputs + auto & out_ids = balloc.get_out_ids(); + out_ids.clear(); + for (int i = 0; i < batch.n_tokens; ++i) { + // batch.logits should have been populated by balloc.prepare() + // If logits[i] is non-zero, this token should produce output + if (batch.logits && batch.logits[i]) { + out_ids.push_back(i); + } + } + + if (debug > 0) { + fprintf(stderr, "%s: batch initialization complete, %zu outputs\n", + __func__, out_ids.size()); + } + + return llama_memory_context_ptr(new llama_kv_cache_paged_context(this)); +} + +llama_memory_context_ptr llama_kv_cache_paged::init_full() { + if (debug > 0) { + fprintf(stderr, "%s: creating context for init_full\n", __func__); + } + + // Return context initialized with this paged cache + auto ctx = new llama_kv_cache_paged_context(this); + + if (debug > 0) { + fprintf(stderr, "%s: context created at %p, creating unique_ptr\n", __func__, (void*)ctx); + } + + llama_memory_context_ptr result(ctx); + + if (debug > 0) { + fprintf(stderr, "%s: unique_ptr created, returning\n", __func__); + } + + return result; +} + +llama_memory_context_ptr llama_kv_cache_paged::init_update( + llama_context * lctx, + bool optimize) { + GGML_UNUSED(lctx); + GGML_UNUSED(optimize); + // TODO: Implement update initialization + return llama_memory_context_ptr( + new llama_kv_cache_paged_context(LLAMA_MEMORY_STATUS_NO_UPDATE)); +} + +bool llama_kv_cache_paged::get_can_shift() const { + // PagedAttention doesn't support context shifting + // (blocks are allocated independently) + return 
false; +} + +void llama_kv_cache_paged::clear(bool data) { + GGML_UNUSED(data); + // Free all block tables + block_tables.clear(); + seq_meta.clear(); + + // Reset all blocks to free state + for (auto & layer : layers) { + for (auto & block : layer.blocks) { + block.ref_count = 0; + block.is_free = true; + } + } + + // Rebuild free block list + free_blocks.clear(); + for (uint32_t i = 0; i < num_blocks; ++i) { + free_blocks.push_back(i); + } + + if (debug > 0) { + fprintf(stderr, "%s: cleared paged KV cache\n", __func__); + } +} + +bool llama_kv_cache_paged::seq_rm( + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + // Remove tokens in range [p0, p1) from sequence + if (debug > 0) { + fprintf(stderr, "%s: called with seq_id=%d, p0=%d, p1=%d\n", + __func__, seq_id, p0, p1); + } + + auto it = block_tables.find(seq_id); + if (it == block_tables.end()) { + // Sequence doesn't exist - already cleared, return true + if (debug > 0) { + fprintf(stderr, "%s: sequence %d doesn't exist, already cleared\n", + __func__, seq_id); + } + return true; + } + + // Normalize parameters: p1 < 0 means "to the end" + if (p0 < 0) { + p0 = 0; + } + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // Get sequence metadata + auto meta_it = seq_meta.find(seq_id); + if (meta_it == seq_meta.end()) { + // No metadata - sequence hasn't been used yet, treat as full removal + auto & blocks = it->second; + for (uint32_t block_id : blocks) { + free_block(block_id); + } + + block_tables.erase(it); + + if (debug > 0) { + fprintf(stderr, "%s: removed sequence %d without metadata (%zu blocks freed)\n", + __func__, seq_id, blocks.size()); + } + + return true; + } + + const auto & meta = meta_it->second; + + // Check if we're removing the entire sequence (from start) + // This includes: removing from position 0, or removing from before/at the minimum position + // We also treat removal from an uninitialized sequence (pos_min == -1) as full removal when p0 == 0 + bool remove_from_start = (p0 == 0) || (p0 <= meta.pos_min) || (meta.pos_min == -1 && p0 == 0); + + if (remove_from_start) { + // Removing from the beginning - clear entire sequence + auto & blocks = it->second; + for (uint32_t block_id : blocks) { + free_block(block_id); + } + + block_tables.erase(it); + seq_meta.erase(seq_id); + + if (debug > 0) { + fprintf(stderr, "%s: removed entire sequence %d (%zu blocks freed, p0=%d, pos_min=%d)\n", + __func__, seq_id, blocks.size(), p0, meta.pos_min); + } + + return true; + } + + // Partial removal from the middle/end is not yet supported + // This would require tracking which blocks are partially used + if (debug > 0) { + fprintf(stderr, "%s: partial sequence removal (p0=%d, p1=%d, pos_min=%d) not yet supported in paged cache\n", + __func__, p0, p1, meta.pos_min); + } + return false; +} + +void llama_kv_cache_paged::seq_cp( + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + GGML_UNUSED(p1); + // Copy sequence - in paged attention, this is efficient via block sharing + auto it_src = block_tables.find(seq_id_src); + if (it_src == block_tables.end()) { + return; + } + + // For simplicity, copy entire sequence (ignore p0, p1 for now) + GGML_UNUSED(p0); + auto & src_blocks = it_src->second; + + // Increment reference count on all blocks + for (uint32_t block_id : src_blocks) { + for (auto & layer : layers) { + if (block_id < layer.blocks.size()) { + layer.blocks[block_id].ref_count++; + } + } + } + + // Share the block table + block_tables[seq_id_dst] = src_blocks; + + // Copy 
metadata + auto it_meta = seq_meta.find(seq_id_src); + if (it_meta != seq_meta.end()) { + seq_meta[seq_id_dst] = it_meta->second; + } + + if (debug > 0) { + fprintf(stderr, "%s: copied sequence %d to %d (%zu blocks shared)\n", + __func__, seq_id_src, seq_id_dst, src_blocks.size()); + } +} + +void llama_kv_cache_paged::seq_keep(llama_seq_id seq_id) { + // Remove all sequences except the specified one + std::vector to_remove; + + for (const auto & entry : block_tables) { + if (entry.first != seq_id) { + to_remove.push_back(entry.first); + } + } + + for (llama_seq_id sid : to_remove) { + seq_rm(sid, -1, -1); + } + + if (debug > 0) { + fprintf(stderr, "%s: kept only sequence %d\n", __func__, seq_id); + } +} + +void llama_kv_cache_paged::seq_add( + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos shift) { + GGML_UNUSED(p1); + // Shift positions in sequence + auto it = seq_meta.find(seq_id); + if (it == seq_meta.end()) { + return; + } + + // Update position metadata + if (p0 >= 0 && it->second.pos_min >= p0) { + it->second.pos_min += shift; + } + if (p0 >= 0 && it->second.pos_max >= p0) { + it->second.pos_max += shift; + } + + if (debug > 0) { + fprintf(stderr, "%s: shifted sequence %d by %d\n", __func__, seq_id, shift); + } +} + +void llama_kv_cache_paged::seq_div( + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + GGML_UNUSED(p0); + GGML_UNUSED(p1); + // Divide positions (used for attention scaling) + // For paged attention, this is mostly metadata-only + auto it = seq_meta.find(seq_id); + if (it == seq_meta.end()) { + return; + } + + if (debug > 0) { + fprintf(stderr, "%s: divided sequence %d positions by %d\n", __func__, seq_id, d); + } + + // Position division would affect logical positioning but not block allocation +} + +llama_pos llama_kv_cache_paged::seq_pos_min(llama_seq_id seq_id) const { + auto it = seq_meta.find(seq_id); + return (it != seq_meta.end()) ? it->second.pos_min : -1; +} + +llama_pos llama_kv_cache_paged::seq_pos_max(llama_seq_id seq_id) const { + auto it = seq_meta.find(seq_id); + return (it != seq_meta.end()) ? it->second.pos_max : -1; +} + +std::map llama_kv_cache_paged::memory_breakdown() const { + // TODO: Implement memory breakdown + return std::map(); +} + +void llama_kv_cache_paged::state_write( + llama_io_write_i & io, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + GGML_UNUSED(io); + GGML_UNUSED(seq_id); + GGML_UNUSED(flags); + // TODO: Implement state serialization + fprintf(stderr, "%s: state saving not yet implemented for paged cache\n", __func__); +} + +void llama_kv_cache_paged::state_read( + llama_io_read_i & io, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + GGML_UNUSED(io); + GGML_UNUSED(seq_id); + GGML_UNUSED(flags); + // TODO: Implement state deserialization + fprintf(stderr, "%s: state loading not yet implemented for paged cache\n", __func__); +} + +// +// PagedAttention specific functions +// + +const std::vector & llama_kv_cache_paged::get_block_table(llama_seq_id seq_id) const { + static const std::vector empty; + auto it = block_tables.find(seq_id); + return (it != block_tables.end()) ? 
it->second : empty; +} + +std::vector llama_kv_cache_paged::get_seq_lens() const { + std::vector lens; + lens.reserve(seq_meta.size()); + + for (const auto & entry : seq_meta) { + lens.push_back(static_cast(entry.second.length)); + } + + return lens; +} + +ggml_tensor * llama_kv_cache_paged::get_k_blocks(int32_t il) const { + // Map model layer ID to KV cache layer ID + auto it = map_layer_ids.find(il); + if (it == map_layer_ids.end()) { + return nullptr; + } + + const int32_t il_kv = it->second; + if (il_kv < 0 || il_kv >= static_cast(layers.size())) { + return nullptr; + } + + return layers[il_kv].k_all_blocks; +} + +ggml_tensor * llama_kv_cache_paged::get_v_blocks(int32_t il) const { + // Map model layer ID to KV cache layer ID + auto it = map_layer_ids.find(il); + if (it == map_layer_ids.end()) { + return nullptr; + } + + const int32_t il_kv = it->second; + if (il_kv < 0 || il_kv >= static_cast(layers.size())) { + return nullptr; + } + + return layers[il_kv].v_all_blocks; +} + +ggml_tensor * llama_kv_cache_paged::build_block_tables_tensor(ggml_context * ctx) const { + // Build block tables tensor for all active sequences + // Shape: [max_blocks_per_seq, n_seqs] + + // During graph building (before any sequences exist), use default sizes + size_t max_blocks; + size_t n_seqs_actual; + + if (block_tables.empty()) { + // Use defaults for graph building + n_seqs_actual = n_seq_max; + // Estimate max blocks based on context size and block size + // Assume each sequence could use the full context + max_blocks = (4096 + block_size - 1) / block_size; // default n_ctx = 4096 + } else { + // Find maximum number of blocks per sequence + max_blocks = 0; + for (const auto & [seq_id, blocks] : block_tables) { + max_blocks = std::max(max_blocks, blocks.size()); + } + n_seqs_actual = block_tables.size(); + } + + // Create tensor + ggml_tensor * tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_blocks, n_seqs_actual); + ggml_set_input(tensor); + + // Fill with block IDs (will be done during set_input) + // For now, the structure is created + return tensor; +} + +ggml_tensor * llama_kv_cache_paged::build_seq_lens_tensor(ggml_context * ctx) const { + // Build sequence lengths tensor + // Shape: [n_seqs] + + // During graph building (before any sequences exist), use default size + const size_t n_seqs = seq_meta.empty() ? 
n_seq_max : seq_meta.size(); + + // Create tensor + ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_set_input(tensor); + + return tensor; +} + +// +// Block management (private) +// + +uint32_t llama_kv_cache_paged::allocate_block() { + if (free_blocks.empty()) { + fprintf(stderr, "%s: ERROR: out of free blocks!\n", __func__); + return UINT32_MAX; + } + + uint32_t block_id = free_blocks.back(); + free_blocks.pop_back(); + + // Mark block as allocated in all layers + for (auto & layer : layers) { + if (block_id < layer.blocks.size()) { + layer.blocks[block_id].is_free = false; + layer.blocks[block_id].ref_count = 1; + } + } + + if (debug > 1) { + fprintf(stderr, "%s: allocated block %u (%zu free remaining)\n", + __func__, block_id, free_blocks.size()); + } + + return block_id; +} + +void llama_kv_cache_paged::free_block(uint32_t block_id) { + if (block_id >= num_blocks) { + return; + } + + // Decrement reference count + for (auto & layer : layers) { + if (block_id < layer.blocks.size()) { + auto & block = layer.blocks[block_id]; + + if (block.ref_count > 0) { + block.ref_count--; + } + + // Free block if reference count reaches zero + if (block.ref_count == 0 && !block.is_free) { + block.is_free = true; + free_blocks.push_back(block_id); + + if (debug > 1) { + fprintf(stderr, "%s: freed block %u (%zu free blocks total)\n", + __func__, block_id, free_blocks.size()); + } + } + } + } +} + +void llama_kv_cache_paged::allocate_blocks_for_sequence( + llama_seq_id seq_id, + uint32_t num_tokens) { + // Calculate number of blocks needed + uint32_t num_blocks_needed = (num_tokens + block_size - 1) / block_size; + + if (debug > 0) { + fprintf(stderr, "%s: allocating %u blocks for sequence %d (%u tokens)\n", + __func__, num_blocks_needed, seq_id, num_tokens); + } + + // Allocate blocks + auto & blocks = block_tables[seq_id]; + blocks.reserve(num_blocks_needed); + + for (uint32_t i = 0; i < num_blocks_needed; ++i) { + uint32_t block_id = allocate_block(); + if (block_id == UINT32_MAX) { + fprintf(stderr, "%s: ERROR: failed to allocate block %u/%u for sequence %d\n", + __func__, i, num_blocks_needed, seq_id); + return; + } + blocks.push_back(block_id); + } + + // Update sequence metadata + auto & meta = seq_meta[seq_id]; + meta.length = num_tokens; + meta.pos_min = 0; + meta.pos_max = static_cast(num_tokens - 1); +} + +// +// Helper functions (private) +// + +size_t llama_kv_cache_paged::total_size() const { + return size_k_bytes() + size_v_bytes(); +} + +size_t llama_kv_cache_paged::size_k_bytes() const { + // TODO: Calculate actual memory size based on tensor layouts + return 0; +} + +size_t llama_kv_cache_paged::size_v_bytes() const { + // TODO: Calculate actual memory size based on tensor layouts + return 0; +} + +void llama_kv_cache_paged::populate_block_tables_tensor(ggml_tensor * tensor) const { + if (!tensor || !tensor->data) { + fprintf(stderr, "%s: ERROR: tensor is null or has no data\n", __func__); + return; + } + + if (debug > 0) { + fprintf(stderr, "%s: populating block tables tensor [%lld, %lld]\n", + __func__, (long long)tensor->ne[0], (long long)tensor->ne[1]); + } + + // Tensor layout: [max_blocks_per_seq, n_seqs] + const int64_t max_blocks = tensor->ne[0]; + const int64_t n_seqs = tensor->ne[1]; + + // Initialize all entries to 0 (invalid block ID) + int32_t * data = reinterpret_cast(tensor->data); + memset(data, 0, ggml_nbytes(tensor)); + + // Fill in block IDs for each active sequence + int seq_idx = 0; + for (const auto & entry : block_tables) { + if 
(seq_idx >= n_seqs) { + fprintf(stderr, "%s: WARNING: more sequences than tensor space (%d >= %lld)\n", + __func__, seq_idx, (long long)n_seqs); + break; + } + + const auto & blocks = entry.second; + const int64_t num_blocks = std::min(static_cast(blocks.size()), max_blocks); + + // Copy block IDs for this sequence + for (int64_t i = 0; i < num_blocks; ++i) { + data[seq_idx * max_blocks + i] = static_cast(blocks[i]); + } + + if (debug > 1) { + fprintf(stderr, "%s: seq %d: %lld blocks [", __func__, entry.first, (long long)num_blocks); + for (int64_t i = 0; i < std::min(num_blocks, (int64_t)4); ++i) { + fprintf(stderr, "%d%s", blocks[i], i < num_blocks - 1 ? ", " : ""); + } + if (num_blocks > 4) fprintf(stderr, "..."); + fprintf(stderr, "]\n"); + } + + seq_idx++; + } + + if (debug > 0) { + fprintf(stderr, "%s: populated %d sequences\n", __func__, seq_idx); + } +} + +void llama_kv_cache_paged::populate_seq_lens_tensor(ggml_tensor * tensor) const { + if (!tensor || !tensor->data) { + fprintf(stderr, "%s: ERROR: tensor is null or has no data\n", __func__); + return; + } + + if (debug > 0) { + fprintf(stderr, "%s: populating seq_lens tensor [%lld]\n", + __func__, (long long)tensor->ne[0]); + } + + // Tensor layout: [n_seqs] + const int64_t n_seqs = tensor->ne[0]; + + // Initialize all entries to 0 + int32_t * data = reinterpret_cast(tensor->data); + memset(data, 0, ggml_nbytes(tensor)); + + // Fill in sequence lengths + int seq_idx = 0; + for (const auto & entry : block_tables) { + if (seq_idx >= n_seqs) { + fprintf(stderr, "%s: WARNING: more sequences than tensor space (%d >= %lld)\n", + __func__, seq_idx, (long long)n_seqs); + break; + } + + const llama_seq_id seq_id = entry.first; + + // Get length from metadata if available, otherwise calculate from block table + uint32_t seq_len = 0; + auto meta_it = seq_meta.find(seq_id); + if (meta_it != seq_meta.end()) { + seq_len = meta_it->second.length; + } else { + // Fallback: use number of blocks * block_size + seq_len = static_cast(entry.second.size() * block_size); + } + + data[seq_idx] = static_cast(seq_len); + + if (debug > 1) { + fprintf(stderr, "%s: seq %d: length = %u\n", __func__, seq_id, seq_len); + } + + seq_idx++; + } + + if (debug > 0) { + fprintf(stderr, "%s: populated %d sequence lengths\n", __func__, seq_idx); + } +} diff --git a/src/llama-kv-cache-paged.h b/src/llama-kv-cache-paged.h new file mode 100644 index 00000000000..7a0f7d4eba9 --- /dev/null +++ b/src/llama-kv-cache-paged.h @@ -0,0 +1,245 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-memory.h" + +#include +#include + +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// Forward declare the paged KV cache class +class llama_kv_cache_paged; + +// +// llama_kv_cache_paged_context - Context for PagedAttention operations +// +// NOTE: The graph building code uses static_cast to llama_kv_cache_context* and calls +// methods specific to that class. To avoid crashes, we inherit from llama_kv_cache_context +// and provide stub implementations that will be replaced with proper PagedAttention logic. 
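+//
+// Illustrative usage sketch (not part of the patch's call sites; names follow the
+// declarations below and the graph code in llama-graph.cpp, exact usage may differ):
+//
+//   const auto * mctx = static_cast<const llama_kv_cache_paged_context *>(params.mctx);
+//   ggml_tensor * k_all = mctx->get_k(ctx0, il);                 // whole paged K tensor for layer il
+//   ggml_tensor * cpy   = mctx->cpy_k(ctx0, k_cur, k_idxs, il);  // adds a GGML_OP_PAGED_CPY node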
+// +class llama_kv_cache_context; // forward declare + +class llama_kv_cache_paged_context : public llama_memory_context_i { +public: + llama_kv_cache_paged_context(llama_memory_status status); + llama_kv_cache_paged_context(llama_kv_cache_paged * kv_paged); + + virtual ~llama_kv_cache_paged_context() { + fprintf(stderr, "llama_kv_cache_paged_context::~llama_kv_cache_paged_context() called\n"); + } + + // llama_memory_context_i interface + bool next() override { + fprintf(stderr, "llama_kv_cache_paged_context::next() called\n"); + return false; + } + bool apply() override { + fprintf(stderr, "llama_kv_cache_paged_context::apply() called\n"); + return true; + } + llama_memory_status get_status() const override { + fprintf(stderr, "llama_kv_cache_paged_context::get_status() called\n"); + return status; + } + const llama_ubatch & get_ubatch() const override { + fprintf(stderr, "llama_kv_cache_paged_context::get_ubatch() called\n"); + return ubatch; + } + + // Get the underlying paged cache + llama_kv_cache_paged * get_kv_paged() const { return kv_paged; } + + // Stub methods to match llama_kv_cache_context interface + // These are called by graph building code via static_cast + // TODO: Implement proper PagedAttention versions + uint32_t get_n_kv() const { return 0; } + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + llama_memory_status status; + llama_kv_cache_paged * kv_paged; + llama_ubatch ubatch; // dummy ubatch +}; + +// +// llama_kv_cache_paged - PagedAttention KV cache implementation +// +// This cache divides memory into fixed-size blocks (similar to virtual memory paging) +// to reduce fragmentation and enable efficient memory sharing between sequences. 
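+//
+// Worked example (hypothetical numbers, assuming the default block_size of 16): a
+// sequence holding 40 tokens needs ceil(40/16) = 3 blocks. If its block table is
+// {7, 2, 9}, token position p maps to physical block block_table[p / 16] at offset
+// p % 16, so position 35 lands in block 9 at offset 3. Sequences that share a prefix
+// can point their block tables at the same physical blocks (ref_count > 1).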
+// +// Key concepts: +// - Block: Fixed-size unit of KV cache storage (e.g., 16 tokens) +// - Block Table: Maps logical token positions to physical blocks per sequence +// - Block Pool: Manages allocation/deallocation of physical blocks +// + +class llama_kv_cache_paged : public llama_memory_i { +public: + // Physical block in memory containing KV data for multiple tokens + struct block { + uint32_t id; // unique block ID + ggml_tensor * k_data; // K cache data for this block + ggml_tensor * v_data; // V cache data for this block + uint32_t ref_count; // reference count for block sharing + bool is_free; // whether block is in free pool + + block() : id(0), k_data(nullptr), v_data(nullptr), ref_count(0), is_free(true) {} + }; + + llama_kv_cache_paged( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t block_size, // tokens per block + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_paged() = default; + + // + // llama_memory_i interface + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // PagedAttention specific API + // + + // Get block size (tokens per block) + uint32_t get_block_size() const { return block_size; } + + // Get total number of blocks + uint32_t get_num_blocks() const { return num_blocks; } + + // Get number of free blocks + uint32_t get_num_free_blocks() const { return static_cast(free_blocks.size()); } + + // Get block table for a sequence (maps token positions to block IDs) + const std::vector & get_block_table(llama_seq_id seq_id) const; + + // Get sequence lengths for all sequences + std::vector get_seq_lens() const; + + // Access to block data tensors (for CUDA kernels) + ggml_tensor * get_k_blocks(int32_t il) const; + ggml_tensor * get_v_blocks(int32_t il) const; + + // Get block tables tensor (for CUDA kernels) + ggml_tensor * build_block_tables_tensor(ggml_context * ctx) const; + + // Get sequence lengths tensor (for CUDA kernels) + ggml_tensor * build_seq_lens_tensor(ggml_context * ctx) const; + + // Populate tensor data for CUDA kernels + void populate_block_tables_tensor(ggml_tensor * tensor) const; + void populate_seq_lens_tensor(ggml_tensor * tensor) const; + +private: +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-private-field" +#endif + const llama_model & model; + const llama_hparams & hparams; + + // 
Block storage per layer + struct kv_layer { + uint32_t il; // layer index in model + + // All blocks for this layer (both used and free) + std::vector blocks; + + // Contiguous tensors holding all blocks for this layer + // Shape: [num_blocks, block_size, num_kv_heads, head_size] + ggml_tensor * k_all_blocks = nullptr; + ggml_tensor * v_all_blocks = nullptr; + }; + + const ggml_type type_k; // data type for K cache + const ggml_type type_v; // data type for V cache + const uint32_t n_seq_max = 1; // max number of sequences +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + const uint32_t block_size = 16; // tokens per block (must be power of 2) + const uint32_t num_blocks = 0; // total number of blocks + + // env: LLAMA_KV_CACHE_DEBUG + int debug = 0; + + // ggml contexts for the KV cache along with allocated backend buffers + std::vector> ctxs_bufs; + + // Block management + std::vector free_blocks; // IDs of free blocks + + // Per-sequence block tables (seq_id -> list of block IDs) + std::unordered_map> block_tables; + + // Per-sequence metadata + struct seq_metadata { + llama_pos pos_min = -1; // minimum position in sequence + llama_pos pos_max = -1; // maximum position in sequence + uint32_t length = 0; // sequence length in tokens + }; + std::unordered_map seq_meta; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + // Block management functions + uint32_t allocate_block(); + void free_block(uint32_t block_id); + void allocate_blocks_for_sequence(llama_seq_id seq_id, uint32_t num_tokens); + + // Helper functions + size_t total_size() const; + size_t size_k_bytes() const; + size_t size_v_bytes() const; +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2a545531a9..e2ec2545684 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-paged.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -7069,6 +7070,21 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, nullptr, reuse); + } else if (cparams.use_paged_attention) { + // Use PagedAttention cache + LLAMA_LOG_INFO("%s: creating paged KV cache\n", __func__); + GGML_ASSERT(!hparams.is_swa_any() && "PagedAttention does not support SWA yet"); + + res = new llama_kv_cache_paged( + *this, + params.type_k, + params.type_v, + cparams.n_ctx_seq, + cparams.n_seq_max, + 16, // block_size (16 tokens per block) + nullptr, + reuse); + LLAMA_LOG_INFO("%s: paged KV cache created successfully\n", __func__); } else { GGML_ASSERT(!hparams.is_swa_any()); diff --git a/src/models/llama.cpp b/src/models/llama.cpp index ab7fd5d0508..0cfcbbed23c 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -14,7 +14,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn_base = build_attn_inp(); + auto * inp_attn = dynamic_cast(inp_attn_base); + auto * inp_attn_paged = dynamic_cast(inp_attn_base); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -80,9 +82,15 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_normed", il); cb(Kcur, "Kcur_normed", il); } - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + if (inp_attn_paged) { + cur = build_attn(inp_attn_paged, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } else { + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 2bf3924df90..53b837a08ca 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2050,9 +2050,12 @@ struct server_context_impl { SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing KV cache\n", p0); - clear_slot(slot); + // For paged KV cache, partial removal may not be supported yet + // In this case, clear the entire sequence to avoid stale KV data + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); // there is no common part left slot.n_prompt_tokens_cache = 0;