diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 7901b71e22..60e813e921 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -233,6 +233,62 @@ vec_apply_llama_rope_cos_sin_interleave_reuse_half(const T* x, const vec_t +__device__ __forceinline__ void scale_store_partial_chunk(const DType* in_ptr, QuantType* out_ptr, + uint32_t lane_elem_offset, + uint32_t chunk_valid, float scale) { + if (chunk_valid == 0 || lane_elem_offset >= chunk_valid) { + return; + } + vec_t vec; + if (lane_elem_offset + vec_size <= chunk_valid) { + vec.cast_load(in_ptr + lane_elem_offset); + } else { +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + uint32_t elem_idx = lane_elem_offset + i; + if (elem_idx < chunk_valid) { + vec_t tmp; + tmp.cast_load(in_ptr + elem_idx); + vec[i] = tmp[0]; + } else { + vec[i] = 0.f; + } + } + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = vec[i] * scale; + } + if (lane_elem_offset + vec_size <= chunk_valid) { + vec.cast_store(out_ptr + lane_elem_offset); + } else { +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + uint32_t elem_idx = lane_elem_offset + i; + if (elem_idx < chunk_valid) { + vec_t tmp; + tmp[0] = vec[i]; + tmp.cast_store(out_ptr + elem_idx); + } + } + } +} + template __global__ void BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel( @@ -485,13 +541,12 @@ __global__ void RopeQuantizeKernel( k_nope_out + get_elem_offset_impl(idx, k_head_idx, elem_offset, k_nope_out_stride, k_nope_out_stride_h); - vec_t k_nope_vec; - k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; - } - k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size); + uint32_t chunk_valid = + (elem_offset < no_rope_dim) ? min(rope_chunk_size, no_rope_dim - elem_offset) : 0u; + uint32_t lane_elem_offset = tx * vec_size; + // Handle tail chunks where no_rope_dim is not a multiple of rope_dim. + scale_store_partial_chunk( + k_nope_in_ptr, k_nope_out_ptr, lane_elem_offset, chunk_valid, quant_scale_kv); } else { // Q Non-RoPE processing: num_qo_heads * no_rope_chunks blocks @@ -506,13 +561,12 @@ __global__ void RopeQuantizeKernel( q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, q_nope_out_stride_h); - vec_t q_nope_vec; - q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; - } - q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); + uint32_t chunk_valid = + (elem_offset < no_rope_dim) ? min(rope_chunk_size, no_rope_dim - elem_offset) : 0u; + uint32_t lane_elem_offset = tx * vec_size; + // Handle tail chunks where no_rope_dim is not a multiple of rope_dim. + scale_store_partial_chunk( + q_nope_in_ptr, q_nope_out_ptr, lane_elem_offset, chunk_valid, quant_scale_q); } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -993,75 +1047,74 @@ cudaError_t RopeQuantize( FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); // Use nested macros for runtime->compile-time dispatch for required constexpr values - DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - constexpr uint32_t vec_size = 32 / sizeof(DType); - constexpr uint32_t bdx = ROPE_DIM / vec_size; - uint32_t num_threads = 128U; - uint32_t bdy = num_threads / bdx; - uint32_t nblks_x = (nnz + bdy - 1) / bdy; - uint32_t rope_chunk_size = rope_dim; - uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; - uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size; - uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + - num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; - void* args[] = {(void*)&q_rope_in, - (void*)&k_rope_in, - (void*)&q_nope_in, - (void*)&k_nope_in, - (void*)&q_rope_out, - (void*)&k_rope_out, - (void*)&q_nope_out, - (void*)&k_nope_out, - (void*)&cos_sin_cache, - (void*)&pos_ids, - (void*)&nnz, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&rope_dim, - (void*)&no_rope_dim, - (void*)&q_rope_in_stride_n, - (void*)&q_rope_in_stride_h, - (void*)&q_nope_in_stride_n, - (void*)&q_nope_in_stride_h, - (void*)&q_rope_out_stride_n, - (void*)&q_rope_out_stride_h, - (void*)&q_nope_out_stride_n, - (void*)&q_nope_out_stride_h, - (void*)&k_rope_in_stride, - (void*)&k_rope_in_stride_h, - (void*)&k_nope_in_stride, - (void*)&k_nope_in_stride_h, - (void*)&k_rope_out_stride, - (void*)&k_rope_out_stride_h, - (void*)&k_nope_out_stride, - (void*)&k_nope_out_stride_h, - (void*)&quant_scale_q, - (void*)&quant_scale_kv}; - auto kernel = RopeQuantizeKernel; - dim3 nblks(nblks_x, total_blocks_y); - dim3 nthrs(bdx, bdy); - - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; - cudaLaunchConfig_t config; - config.gridDim = nblks; - config.blockDim = nthrs; - config.stream = stream; - config.dynamicSmemBytes = 0; - config.attrs = attribute; - config.numAttrs = 1; - - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( - &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out, - q_nope_out, k_nope_out, cos_sin_cache, pos_ids, nnz, num_qo_heads, num_kv_heads, rope_dim, - no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, - q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, - q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, - k_nope_in_stride_h, k_rope_out_stride, k_rope_out_stride_h, k_nope_out_stride, - k_nope_out_stride_h, quant_scale_q, quant_scale_kv)); - }); + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + uint32_t bdx = (rope_dim + vec_size - 1) / vec_size; + bdx = std::max(1u, bdx); + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = std::max(1u, num_threads / bdx); + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunk_size = rope_dim; + uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; + uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size; + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; + void* args[] = {(void*)&q_rope_in, + (void*)&k_rope_in, + (void*)&q_nope_in, + (void*)&k_nope_in, + (void*)&q_rope_out, + (void*)&k_rope_out, + (void*)&q_nope_out, + (void*)&k_nope_out, + (void*)&cos_sin_cache, + (void*)&pos_ids, + (void*)&nnz, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rope_dim, + (void*)&no_rope_dim, + (void*)&q_rope_in_stride_n, + (void*)&q_rope_in_stride_h, + (void*)&q_nope_in_stride_n, + (void*)&q_nope_in_stride_h, + (void*)&q_rope_out_stride_n, + (void*)&q_rope_out_stride_h, + (void*)&q_nope_out_stride_n, + (void*)&q_nope_out_stride_h, + (void*)&k_rope_in_stride, + (void*)&k_rope_in_stride_h, + (void*)&k_nope_in_stride, + (void*)&k_nope_in_stride_h, + (void*)&k_rope_out_stride, + (void*)&k_rope_out_stride_h, + (void*)&k_nope_out_stride, + (void*)&k_nope_out_stride_h, + (void*)&quant_scale_q, + (void*)&quant_scale_kv}; + auto kernel = RopeQuantizeKernel; + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( + &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out, + q_nope_out, k_nope_out, cos_sin_cache, pos_ids, nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, k_nope_in_stride_h, + k_rope_out_stride, k_rope_out_stride_h, k_nope_out_stride, k_nope_out_stride_h, + quant_scale_q, quant_scale_kv)); }); return cudaSuccess; @@ -1082,72 +1135,63 @@ cudaError_t RopeQuantizeAppendPagedKVCache( size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { - DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - constexpr uint32_t vec_size = 32 / sizeof(DType); - constexpr uint32_t bdx = ROPE_DIM / vec_size; - uint32_t num_threads = 128U; - uint32_t bdy = num_threads / bdx; - uint32_t nblks_x = (nnz + bdy - 1) / bdy; - uint32_t rope_chunks = 1; - uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; - - // GQA/MHA: Q rope + K rope + K nope + V + Q nope - uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + - num_kv_heads * no_rope_chunks + num_kv_heads + - num_qo_heads * no_rope_chunks; - - dim3 nblks(nblks_x, total_blocks_y); - dim3 nthrs(bdx, bdy); - - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; - cudaLaunchConfig_t config; - config.gridDim = nblks; - config.blockDim = nthrs; - config.stream = stream; - config.dynamicSmemBytes = 0; - config.attrs = attribute; - config.numAttrs = 1; - - auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; - RopeQuantizeAppendPagedKVCacheParams params; - params.nnz = nnz; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.rope_dim = rope_dim; - params.no_rope_dim = no_rope_dim; - params.q_rope_in_stride_n = q_rope_in_stride_n; - params.q_rope_in_stride_h = q_rope_in_stride_h; - params.q_nope_in_stride_n = q_nope_in_stride_n; - params.q_nope_in_stride_h = q_nope_in_stride_h; - params.q_rope_out_stride_n = q_rope_out_stride_n; - params.q_rope_out_stride_h = q_rope_out_stride_h; - params.q_nope_out_stride_n = q_nope_out_stride_n; - params.q_nope_out_stride_h = q_nope_out_stride_h; - params.k_rope_in_stride = k_rope_in_stride; - params.k_rope_in_stride_h = k_rope_in_stride_h; - params.k_nope_in_stride = k_nope_in_stride; - params.k_nope_in_stride_h = k_nope_in_stride_h; - params.v_in_stride = v_in_stride; - params.v_in_stride_h = v_in_stride_h; - params.quant_scale_q = quant_scale_q; - params.quant_scale_kv = quant_scale_kv; - - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, - // inputs - q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, - // q outputs - q_rope_out, q_nope_out, - // cache + indices - paged_kv, batch_indices, positions, - // rope tables - cos_sin_cache, pos_ids, - // params - params)); - }); + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + uint32_t bdx = (rope_dim + vec_size - 1) / vec_size; + bdx = std::max(1u, bdx); + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = std::max(1u, num_threads / bdx); + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_kv_heads + + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h; + params.v_in_stride = v_in_stride; + params.v_in_stride_h = v_in_stride_h; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( + &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, q_rope_out, q_nope_out, + paged_kv, batch_indices, positions, cos_sin_cache, pos_ids, params)); }); return cudaSuccess; @@ -1166,81 +1210,75 @@ cudaError_t RopeQuantizeAppendPagedMLACache( size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { - DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - constexpr uint32_t vec_size = 32 / sizeof(DType); - constexpr uint32_t bdx = ROPE_DIM / vec_size; - uint32_t num_threads = 128U; - uint32_t bdy = num_threads / bdx; - uint32_t nblks_x = (nnz + bdy - 1) / bdy; - uint32_t rope_chunks = 1; - uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; - - // MLA: Q rope + K rope + K nope + Q nope (no V) - constexpr uint32_t num_kv_heads = 1; - uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + - num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; - - dim3 nblks(nblks_x, total_blocks_y); - dim3 nthrs(bdx, bdy); - - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; - cudaLaunchConfig_t config; - config.gridDim = nblks; - config.blockDim = nthrs; - config.stream = stream; - config.dynamicSmemBytes = 0; - config.attrs = attribute; - config.numAttrs = 1; - - auto kernel = - RopeQuantizeAppendPagedKVCacheKernel>; - // For MLA: pass v_in as nullptr, num_kv_heads=1, duplicate 2D K strides for head strides, and - // 0 V strides - DType* v_in_nullptr = nullptr; - uint32_t num_kv_heads_1 = 1; - size_t k_rope_in_stride_h_dup = k_rope_in_stride; - size_t k_nope_in_stride_h_dup = k_nope_in_stride; - size_t v_in_stride_zero = 0, v_in_stride_h_zero = 0; - RopeQuantizeAppendPagedKVCacheParams params; - params.nnz = nnz; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = 1u; - params.rope_dim = rope_dim; - params.no_rope_dim = no_rope_dim; - params.q_rope_in_stride_n = q_rope_in_stride_n; - params.q_rope_in_stride_h = q_rope_in_stride_h; - params.q_nope_in_stride_n = q_nope_in_stride_n; - params.q_nope_in_stride_h = q_nope_in_stride_h; - params.q_rope_out_stride_n = q_rope_out_stride_n; - params.q_rope_out_stride_h = q_rope_out_stride_h; - params.q_nope_out_stride_n = q_nope_out_stride_n; - params.q_nope_out_stride_h = q_nope_out_stride_h; - params.k_rope_in_stride = k_rope_in_stride; - params.k_rope_in_stride_h = k_rope_in_stride_h_dup; - params.k_nope_in_stride = k_nope_in_stride; - params.k_nope_in_stride_h = k_nope_in_stride_h_dup; - params.v_in_stride = 0; - params.v_in_stride_h = 0; - params.quant_scale_q = quant_scale_q; - params.quant_scale_kv = quant_scale_kv; - - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, - // inputs - q_rope_in, k_rope_in, q_nope_in, k_nope_in, - v_in_nullptr, - // q outputs - q_rope_out, q_nope_out, - // cache + indices - paged_kv_mla, batch_indices, positions, - // rope tables - cos_sin_cache, pos_ids, - // params - params)); - }); + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + uint32_t bdx = (rope_dim + vec_size - 1) / vec_size; + bdx = std::max(1u, bdx); + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = std::max(1u, num_threads / bdx); + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + constexpr uint32_t num_kv_heads = 1; + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + DType* v_in_nullptr = nullptr; + uint32_t num_kv_heads_1 = 1; + size_t k_rope_in_stride_h_dup = k_rope_in_stride; + size_t k_nope_in_stride_h_dup = k_nope_in_stride; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = 1u; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h_dup; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h_dup; + params.v_in_stride = 0; + params.v_in_stride_h = 0; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, + v_in_nullptr, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv_mla, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // params + params)); }); return cudaSuccess; @@ -1253,65 +1291,93 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, cudaStream_t stream = nullptr) { - int dev_id = 0; - int num_sms = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - // operate on 16 Bytes at a time - constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); - // how many threads needed per head_dim - constexpr uint32_t bdx = HEAD_DIM / vec_size; - // how many threads needed per block - uint32_t num_threads = std::max(128U, bdx); - // how many tokens can we process in a block - uint32_t bdy = num_threads / bdx; - // how many blocks needed to process all tokens - uint32_t nblks_x = (nnz + bdy - 1) / bdy; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&q_rope, - (void*)&k_rope, - (void*)&cos_sin_cache, - (void*)&pos_ids, - (void*)&nnz, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&rotary_dim, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&k_stride_n, - (void*)&k_stride_h, - (void*)&q_rope_stride_n, - (void*)&q_rope_stride_h, - (void*)&k_rope_stride_n, - (void*)&k_rope_stride_h}; - auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel; + if (head_dim < rotary_dim) { + std::ostringstream err_msg; + err_msg << "head_dim (" << head_dim << ") must be >= rotary_dim (" << rotary_dim << ")"; + FLASHINFER_ERROR(err_msg.str()); + } - int num_blocks_per_sm_0 = 0; - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0)); - uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms; + // We have better performance with this kernel with these head_dim instead of RopeQuantize + if (head_dim == 64 || head_dim == 128 || head_dim == 256 || head_dim == 512) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - if ((nnz + bdy - 1) / bdy >= num_ctas_0) { - dim3 nblks(nblks_x); - dim3 nthrs(bdx, bdy); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream)); - } else { - dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); - dim3 nthrs(bdx, bdy); - auto kernel_1 = - BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream)); - } + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + // operate on 16 Bytes at a time + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + // how many threads needed per head_dim + constexpr uint32_t bdx = HEAD_DIM / vec_size; + // how many threads needed per block + uint32_t num_threads = std::max(128U, bdx); + // how many tokens can we process in a block + uint32_t bdy = num_threads / bdx; + // how many blocks needed to process all tokens + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&cos_sin_cache, + (void*)&pos_ids, + (void*)&nnz, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h}; + auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel; + + int num_blocks_per_sm_0 = 0; + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0)); + uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms; + + if ((nnz + bdy - 1) / bdy >= num_ctas_0) { + dim3 nblks(nblks_x); + dim3 nthrs(bdx, bdy); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream)); + } else { + dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); + dim3 nthrs(bdx, bdy); + auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel< + INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream)); + } + }); }); - }); + return cudaSuccess; + } - return cudaSuccess; + const uint32_t rope_dim = rotary_dim; + const uint32_t no_rope_dim = head_dim - rotary_dim; + + // Route to RopeQuantize kernel + DType* q_rope_in = q; + DType* k_rope_in = k; + DType* q_nope_in = q + rotary_dim; + DType* k_nope_in = k + rotary_dim; + DType* q_rope_out = q_rope; + DType* k_rope_out = k_rope; + DType* q_nope_out = q_rope + rotary_dim; + DType* k_nope_out = k_rope + rotary_dim; + + return RopeQuantize( + q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out, q_nope_out, k_nope_out, + cos_sin_cache, pos_ids, nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, q_stride_n, + q_stride_h, q_stride_n, q_stride_h, q_rope_stride_n, q_rope_stride_h, q_rope_stride_n, + q_rope_stride_h, k_stride_n, k_stride_h, k_stride_n, k_stride_h, k_rope_stride_n, + k_rope_stride_h, k_rope_stride_n, k_rope_stride_h, /*quant_scale_q=*/1.0f, + /*quant_scale_kv=*/1.0f, interleave, /*enable_pdl=*/false, stream); } template diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 8e694088e5..651f43d5e9 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -300,6 +300,10 @@ def forward_cuda( (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1), (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1), (256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2), + (192, 128, 4096, 9231, True, torch.bfloat16, "cuda", 3, 231, 3, 2), + (80, 64, 1024, 10000, False, torch.bfloat16, "cuda", 4, 64, 2, 2), + (112, 64, 2048, 12000, True, torch.bfloat16, "cuda", 5, 77, 2, 1), + (160, 96, 8192, 10000, False, torch.bfloat16, "cuda", 2, 128, 6, 3), ], ) def test_rope_cos_sin_cache(