8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -1017,6 +1017,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"--pagedattention"},
"enable PagedAttention for KV cache (experimental, requires CUDA)",
[](common_params & params) {
fprintf(stderr, "DEBUG: --pagedattention flag parsed, setting params.use_paged_attention = true\n");
params.use_paged_attention = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
4 changes: 4 additions & 0 deletions common/common.cpp
@@ -1249,6 +1249,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
struct llama_context_params common_context_params_to_llama(const common_params & params) {
auto cparams = llama_context_default_params();

fprintf(stderr, "DEBUG common_context_params_to_llama: params.use_paged_attention = %s\n",
params.use_paged_attention ? "true" : "false");

cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_batch = params.n_batch;
@@ -1275,6 +1278,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.kv_unified = params.kv_unified;
cparams.use_paged_attention = params.use_paged_attention;

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
1 change: 1 addition & 0 deletions common/common.h
@@ -406,6 +406,7 @@ struct common_params {
bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
bool use_paged_attention = false; // enable PagedAttention (experimental, requires CUDA)

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
28 changes: 28 additions & 0 deletions ggml/include/ggml.h
@@ -537,6 +537,8 @@ extern "C" {

GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_PAGED_ATTENTION,
GGML_OP_PAGED_CPY,
GGML_OP_SSM_CONV,
GGML_OP_SSM_SCAN,
GGML_OP_WIN_PART,
@@ -2312,6 +2314,32 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * sinks);

// PagedAttention (paged KV cache attention)
// q: [n_tokens, n_heads, head_size]
// k_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
// v_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
// block_tables: [n_seqs, max_blocks_per_seq] (int32)
// seq_lens: [n_seqs] (int32)
GGML_API struct ggml_tensor * ggml_paged_attention(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k_cache,
struct ggml_tensor * v_cache,
struct ggml_tensor * block_tables,
struct ggml_tensor * seq_lens,
int32_t block_size,
float scale);

// Copy K/V data to paged cache blocks (similar to vLLM's reshape_and_cache)
// kv_cur: [head_size, n_heads, n_tokens] - K or V current
// kv_cache: [num_blocks, n_kv_heads, head_size, block_size] - paged K or V cache
// slot_idxs: [n_tokens] (int32) - cache slot index for each token
GGML_API struct ggml_tensor * ggml_paged_cpy(
struct ggml_context * ctx,
struct ggml_tensor * kv_cur,
struct ggml_tensor * kv_cache,
struct ggml_tensor * slot_idxs);
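
The two calls are meant to be used together when building a graph: ggml_paged_cpy scatters the current K/V tokens into their cache slots, and ggml_paged_attention then attends over the paged pools. The following is a minimal construction sketch, not taken from this PR; the helper name, the use of ggml_build_forward_expand to record the cache writes, and the assumption that those writes are scheduled before the attention read are all illustrative assumptions.

#include "ggml.h"

// Sketch only (hypothetical helper): wire the paged cache writes and the
// paged attention read into an existing graph. Tensor shapes follow the
// comments above; block_tables, seq_lens and slot_idxs are int32 tensors.
static struct ggml_tensor * build_paged_attn_sketch(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k_cur,   struct ggml_tensor * v_cur,
        struct ggml_tensor  * k_cache, struct ggml_tensor * v_cache,
        struct ggml_tensor  * block_tables,
        struct ggml_tensor  * seq_lens,
        struct ggml_tensor  * slot_idxs,
        int32_t               block_size,
        float                 kq_scale) {
    // scatter the new K/V tokens into their per-token cache slots
    ggml_build_forward_expand(gf, ggml_paged_cpy(ctx, k_cur, k_cache, slot_idxs));
    ggml_build_forward_expand(gf, ggml_paged_cpy(ctx, v_cur, v_cache, slot_idxs));

    // attend over the paged cache using the per-sequence block tables
    struct ggml_tensor * out = ggml_paged_attention(ctx, q, k_cache, v_cache,
                                                    block_tables, seq_lens,
                                                    block_size, kq_scale);
    ggml_build_forward_expand(gf, out);
    return out;
}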

// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
4 changes: 4 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2062,6 +2062,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_PAGED_ATTENTION:
{
// nop (CUDA-only operation)
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -32,6 +32,8 @@
#include "ggml-cuda/opt-step-sgd.cuh"
#include "ggml-cuda/out-prod.cuh"
#include "ggml-cuda/pad.cuh"
#include "ggml-cuda/paged-attention-backend.cuh"
#include "ggml-cuda/paged-cpy.cuh"
#include "ggml-cuda/pool2d.cuh"
#include "ggml-cuda/quantize.cuh"
#include "ggml-cuda/rope.cuh"
@@ -2719,6 +2721,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_OPT_STEP_SGD:
ggml_cuda_opt_step_sgd(ctx, dst);
break;
case GGML_OP_PAGED_ATTENTION:
ggml_cuda_op_paged_attention(ctx, dst);
break;
case GGML_OP_PAGED_CPY:
ggml_cuda_op_paged_cpy(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
@@ -4564,6 +4572,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
case GGML_OP_PAGED_ATTENTION:
return ggml_cuda_can_paged_attention(op);
case GGML_OP_PAGED_CPY:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
default:
223 changes: 223 additions & 0 deletions ggml/src/ggml-cuda/paged-attention-backend.cu
@@ -0,0 +1,223 @@
/**
* GGML CUDA Backend for PagedAttention
*
* This file provides the CUDA backend implementation for the GGML_OP_PAGED_ATTENTION operation.
* It bridges GGML's operation framework with the PagedAttention CUDA kernels.
*
* NOTE: PagedAttention is currently experimental and only supported on CUDA.
* MUSA support is disabled due to compiler compatibility issues.
*/

// PagedAttention is not yet supported on MUSA
#ifndef GGML_USE_MUSA

#include "common.cuh"
#include "paged-attention.cuh"
#include "paged-attention-backend.cuh"

// Extract parameters from GGML tensor
static void ggml_cuda_op_paged_attention_get_params(
const ggml_tensor * dst,
float * scale,
int32_t * block_size) {

const float * params = (const float *)dst->op_params;
*scale = params[0];
*block_size = (int32_t)params[1];
}
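
This read side implies a specific packing on the graph-construction side: op_params holds two floats, scale at index 0 and block_size stored as a float at index 1. A producer-side sketch under that assumption (the helper name is hypothetical and not part of this PR):

#include <string.h>

// Assumed packing that matches the reader above:
// op_params[0] = scale, op_params[1] = (float) block_size.
// Storing the block size as a float is exact for the supported
// sizes (8/16/32), which stay well below the 2^24 float limit.
static void ggml_paged_attention_pack_params(struct ggml_tensor * dst,
                                             float scale, int32_t block_size) {
    const float params[2] = { scale, (float) block_size };
    memcpy(dst->op_params, params, sizeof(params));
}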

// Main CUDA backend function for PagedAttention
void ggml_cuda_op_paged_attention(
ggml_backend_cuda_context & ctx,
ggml_tensor * dst) {

const ggml_tensor * q = dst->src[0]; // query
const ggml_tensor * k_cache = dst->src[1]; // key cache (paged)
const ggml_tensor * v_cache = dst->src[2]; // value cache (paged)
const ggml_tensor * block_tables = dst->src[3]; // block tables
const ggml_tensor * seq_lens = dst->src[4]; // sequence lengths
const ggml_tensor * alibi_slopes = dst->src[5]; // optional ALiBi slopes (can be nullptr)

// Extract parameters
float scale;
int32_t block_size;
ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);

// Get tensor dimensions
const int64_t head_size = q->ne[0];
const int64_t n_heads = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];

const int64_t n_kv_heads = k_cache->ne[2];
const int64_t num_blocks = k_cache->ne[0];

const int64_t max_blocks_per_seq = block_tables->ne[0];

// Validate tensor dimensions
GGML_ASSERT(n_tokens > 0 && "Number of query tokens must be positive");
GGML_ASSERT(n_seqs > 0 && "Number of sequences must be positive");
GGML_ASSERT(num_blocks > 0 && "Number of KV cache blocks must be positive");
GGML_ASSERT(max_blocks_per_seq > 0 && "Max blocks per sequence must be positive");

// Validate that we have enough blocks available
// Note: This is a soft check - actual usage depends on sequence lengths
GGML_ASSERT(num_blocks >= max_blocks_per_seq &&
"Total number of blocks should be >= max blocks per sequence");

// For PagedAttention, typically we have one query per sequence (decode mode)
// or multiple queries per sequence (prefill mode)
GGML_ASSERT(n_tokens <= n_seqs * 1024 &&
"Number of tokens seems unusually large relative to batch size");

// Get pointers
void * out_ptr = dst->data;
const void * q_ptr = q->data;
const void * k_cache_ptr = k_cache->data;
const void * v_cache_ptr = v_cache->data;
const int32_t * block_tables_ptr = (const int32_t *)block_tables->data;
const int32_t * seq_lens_ptr = (const int32_t *)seq_lens->data;

// Debug: Check for null pointers
GGML_ASSERT(out_ptr != nullptr && "Output pointer is null");
GGML_ASSERT(q_ptr != nullptr && "Query pointer is null");
GGML_ASSERT(k_cache_ptr != nullptr && "K cache pointer is null");
GGML_ASSERT(v_cache_ptr != nullptr && "V cache pointer is null");
GGML_ASSERT(block_tables_ptr != nullptr && "Block tables pointer is null");
GGML_ASSERT(seq_lens_ptr != nullptr && "Sequence lengths pointer is null");

// Get ALiBi slopes pointer if provided
const float * alibi_slopes_ptr = nullptr;
if (alibi_slopes != nullptr) {
// ALiBi slopes should be a 1D tensor with one slope per attention head
GGML_ASSERT(alibi_slopes->type == GGML_TYPE_F32 &&
"ALiBi slopes must be float32");
GGML_ASSERT(alibi_slopes->ne[0] == n_heads &&
"ALiBi slopes tensor must have one value per head");
alibi_slopes_ptr = (const float *)alibi_slopes->data;
}

// Calculate max sequence length (needed to decide V1 vs V2)
int max_seq_len = 0;
for (int i = 0; i < n_seqs; i++) {
if (seq_lens_ptr[i] > max_seq_len) {
max_seq_len = seq_lens_ptr[i];
}
}

// Get CUDA stream
cudaStream_t stream = ctx.stream();

// Decide whether to use V1 or V2
const bool use_v1 = ggml_cuda_paged_attention::should_use_v1(
max_seq_len, n_seqs, n_heads);

// Launch appropriate kernel
if (use_v1) {
ggml_cuda_paged_attention::paged_attention_v1_launcher(
out_ptr,
q_ptr,
k_cache_ptr,
v_cache_ptr,
n_seqs,
n_heads,
n_kv_heads,
head_size,
block_size,
max_blocks_per_seq,
block_tables_ptr,
seq_lens_ptr,
max_seq_len,
scale,
alibi_slopes_ptr,
q->type,
k_cache->type,
stream);
} else {
ggml_cuda_paged_attention::paged_attention_v2_launcher(
out_ptr,
q_ptr,
k_cache_ptr,
v_cache_ptr,
n_seqs,
n_heads,
n_kv_heads,
head_size,
block_size,
max_blocks_per_seq,
block_tables_ptr,
seq_lens_ptr,
max_seq_len,
scale,
alibi_slopes_ptr,
q->type,
k_cache->type,
ctx.pool(),
stream);
}

// Check for errors
CUDA_CHECK(cudaGetLastError());
}

// Check if PagedAttention is supported for given configuration
bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k_cache = dst->src[1];

// Check data types
if (q->type != GGML_TYPE_F16 && q->type != GGML_TYPE_F32) {
return false;
}

if (k_cache->type != GGML_TYPE_F16 && k_cache->type != GGML_TYPE_F32) {
return false;
}

// Check head size is supported
const int64_t head_size = q->ne[0];
const int supported_head_sizes[] = {32, 64, 80, 96, 112, 120, 128, 192, 256};
bool head_size_supported = false;

for (int hs : supported_head_sizes) {
if (head_size == hs) {
head_size_supported = true;
break;
}
}

if (!head_size_supported) {
return false;
}

// Extract block size and check it's supported
float scale;
int32_t block_size;
ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);

if (block_size != 8 && block_size != 16 && block_size != 32) {
return false;
}

return true;
}

#else // GGML_USE_MUSA

// Stub implementations for MUSA (PagedAttention not yet supported)
#include "common.cuh"

void ggml_cuda_op_paged_attention(
ggml_backend_cuda_context & ctx,
ggml_tensor * dst) {
GGML_UNUSED(ctx);
GGML_UNUSED(dst);
GGML_ABORT("PagedAttention is not yet supported on MUSA");
}

bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
GGML_UNUSED(dst);
return false;
}

#endif // GGML_USE_MUSA
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/paged-attention-backend.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

void ggml_cuda_op_paged_attention(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_can_paged_attention(const ggml_tensor * dst);