From 51af95c84a3f0a30dc539c10a1f8096bea9e9330 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 20 Nov 2025 07:34:25 +0000 Subject: [PATCH] upd --- csrc/flashinfer_sampling_binding.cu | 3 +- csrc/renorm.cu | 12 +- flashinfer/sampling.py | 25 +- flashinfer/utils.py | 9 +- include/flashinfer/sampling.cuh | 440 ++++++++++++++++++++++++++++ include/flashinfer/utils.cuh | 9 +- 6 files changed, 487 insertions(+), 11 deletions(-) diff --git a/csrc/flashinfer_sampling_binding.cu b/csrc/flashinfer_sampling_binding.cu index 8e4bbb98b8..bcf5f98ee0 100644 --- a/csrc/flashinfer_sampling_binding.cu +++ b/csrc/flashinfer_sampling_binding.cu @@ -55,7 +55,8 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, Optional maybe_top_k_arr, int64_t top_k_val); void top_k_mask_logits(TensorView logits, TensorView mask_logits, - Optional maybe_top_k_arr, int64_t top_k_val); + Optional maybe_top_k_arr, int64_t top_k_val, + TensorView row_states_buffer); void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_ids, TensorView target_probs, TensorView output_token_ids, diff --git a/csrc/renorm.cu b/csrc/renorm.cu index 1e2aa45769..a1054bc03c 100644 --- a/csrc/renorm.cu +++ b/csrc/renorm.cu @@ -59,8 +59,10 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, } void top_k_mask_logits(TensorView logits, TensorView mask_logits, - Optional maybe_top_k_arr, int64_t top_k_val) { + Optional maybe_top_k_arr, int64_t top_k_val, + TensorView row_states_buffer) { CHECK_INPUT(logits); + CHECK_INPUT(row_states_buffer); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) unsigned int batch_size = logits.size(0); unsigned int vocab_size = logits.size(1); @@ -68,10 +70,14 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits, cudaSetDevice(logits.device().device_id); auto stream = get_stream(logits.device()); - cudaError_t status = sampling::TopKMaskLogits( + + cudaError_t status; + // Use multi-CTA kernel + status = sampling::TopKMaskLogitsMultiCTA( static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size, - top_k_val, vocab_size, stream); + top_k_val, vocab_size, + static_cast*>(row_states_buffer.data_ptr()), stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKMaskLogits failed with error code " << cudaGetErrorString(status); diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 3ac6367ff5..deaa3e8e0d 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -377,20 +377,25 @@ def _fake_top_k_renorm_probs( # torch library for top_k_mask_logits - @register_custom_op("flashinfer::top_k_mask_logits", mutates_args=()) + @register_custom_op( + "flashinfer::top_k_mask_logits", mutates_args=("row_states_buffer",) + ) def top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, + row_states_buffer: torch.Tensor, ) -> torch.Tensor: logits = logits.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None mask_logits = torch.empty_like(logits) + module.top_k_mask_logits( logits, mask_logits, maybe_top_k_arr, top_k_val, + row_states_buffer, ) return mask_logits @@ -399,8 +404,9 @@ def _fake_top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, + row_states_buffer: torch.Tensor, ) -> torch.Tensor: - return torch.empty_like(logits) + return torch.empty_like(logits, dtype=torch.float32) # torch library for chain_speculative_sampling @@ -1346,8 +1352,21 @@ def top_k_mask_logits( top_k_renorm_probs """ _check_tensor_param(top_k, logits) + + # Allocate row_states buffer for multi-CTA kernel (1MB is enough for any GPU) + buffer_bytes = 1024 * 1024 # 1MB + row_states_buffer = _get_cache_buf( + f"top_k_mask_logits_row_states_{logits.device}", + buffer_bytes, + logits.device, + zero_init=True, + ) + + # Note: row_states_buffer is zero-initialized on first allocation by _get_cache_buf + # Kernel will reset arrival_counter to 0 at the end of each launch + return get_sampling_module().top_k_mask_logits( - logits, *_to_tensor_scalar_tuple(top_k) + logits, *_to_tensor_scalar_tuple(top_k), row_states_buffer ) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 76689bab84..cfad7f591a 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -203,11 +203,16 @@ def get_alibi_slopes(n_heads: int) -> torch.Tensor: _cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} -def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: +def _get_cache_buf( + name: str, bytes: int, device: torch.device, zero_init: bool = False +) -> torch.Tensor: key = (name, device) buf = _cache_buf.get(key) if buf is None or buf.size(0) < bytes: - buf = torch.empty(bytes, dtype=torch.uint8, device=device) + if zero_init: + buf = torch.zeros(bytes, dtype=torch.uint8, device=device) + else: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) _cache_buf[key] = buf return buf diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 03d4bfa8e2..dba4935134 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -2071,6 +2071,446 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar }); } +// ==================== Multi-CTA Top-K Implementation ==================== + +// Atomic min/max for float using CAS +__device__ __forceinline__ float atomicMinFloat(float* addr, float value) { + int* addr_as_int = (int*)addr; + int old = *addr_as_int, assumed; + + do { + assumed = old; + old = atomicCAS(addr_as_int, assumed, __float_as_int(fminf(value, __int_as_float(assumed)))); + } while (assumed != old); + + return __int_as_float(old); +} + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + int* addr_as_int = (int*)addr; + int old = *addr_as_int, assumed; + + do { + assumed = old; + old = atomicCAS(addr_as_int, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + + return __int_as_float(old); +} + +// Acquire/Release primitives for inter-CTA synchronization +__device__ __forceinline__ int ld_acquire(int* ptr) { + int state = 0; + +#if (__CUDA_ARCH__ >= 700) + // SM70 and newer use memory consistency qualifiers + // Acquire pattern using acquire modifier + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#else + asm volatile("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif + + return state; +} + +__device__ __forceinline__ void red_release(int* ptr, int val) { +#if (__CUDA_ARCH__ >= 700) + // SM70 and newer use memory consistency qualifiers + // Release pattern using acq_rel fence + relaxed modifier + // (The fence also releases data that was weakly-written by other threads prior to the last + // syncthreads) + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#else + __threadfence(); + atomicAdd(ptr, val); +#endif +} + +__device__ __forceinline__ void st_release(int* ptr, int val) { +#if (__CUDA_ARCH__ >= 700) + // SM70 and newer use memory consistency qualifiers + // Release pattern: fence + release store + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#else + __threadfence(); + atomicExch(ptr, val); +#endif +} + +// Wait until the value at ptr reaches target_val using acquire semantics +// Only thread 0 spins, then all threads synchronize +__device__ __forceinline__ void wait_ge(int* ptr, int target_val, int thread_idx) { + if (thread_idx == 0) { +#pragma unroll 1 + while (ld_acquire(ptr) < target_val) { + } + } + __syncthreads(); +} + +// Global state for multi-CTA reduction (one per row) +template +struct RowReductionState { + // Ping-pong buffers for atomic reduction + int count_0_buf[2]; + int count_1_buf[2]; + T min_buf[2]; + T max_buf[2]; + + // Arrival counter for acquire/release synchronization + int arrival_counter; +}; + +template +__global__ void __launch_bounds__(BLOCK_THREADS) TopKMaskLogitsKernel_MultiCTA( + DType* logits, // [batch, vocab_size] + DType* masked_logits, // [batch, vocab_size] + IdType* top_k_arr, // [batch] or nullptr + uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size, + RowReductionState* row_states, // [num_groups], num_groups = gridDim.x / ctas_per_group + uint32_t chunk_size, // elements per CTA (must be multiple of VEC_SIZE) + uint32_t ctas_per_group) // CTAs per row +{ + const uint32_t global_cta_id = blockIdx.x; + const uint32_t group_id = global_cta_id / ctas_per_group; + const uint32_t cta_in_group = global_cta_id % ctas_per_group; + const uint32_t tx = threadIdx.x; + + // Shared memory layout: [temp_storage] [padding] [logits data (16-byte aligned)] + extern __shared__ uint8_t smem[]; + auto* temp_storage = reinterpret_cast*>(smem); + + // Align logits to 16 bytes + size_t temp_storage_size = sizeof(RenormTempStorage); + size_t logits_offset = ((temp_storage_size + 15) / 16) * 16; + DType* shared_logits = reinterpret_cast(smem + logits_offset); + + // Note: arrival_counter and count buffers should be pre-initialized to zero on the host side + + // Persistent iteration counter for double buffering (never resets across rows) + int persistent_iteration = 0; + + // Calculate total number of iterations for persistent loop + uint32_t num_groups = gridDim.x / ctas_per_group; + uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups; + + int barrier_phase = 0; + // Each group uses its own state (groups process rows sequentially in persistent loop) + RowReductionState* state = &row_states[group_id]; + + // Initialize min/max buffer for this row (first CTA only) + if (cta_in_group == 0 && tx == 0) { + state->min_buf[0] = cuda::std::numeric_limits::max(); + state->max_buf[0] = cuda::std::numeric_limits::lowest(); + } + + // First barrier: ensure all CTAs see the initialized min/max values + if (tx == 0) { + red_release(&state->arrival_counter, 1); + } + int target = (barrier_phase + 1) * ctas_per_group; + wait_ge(&state->arrival_counter, target, tx); + barrier_phase++; + + // Persistent loop over rows + for (uint32_t iter = 0; iter < total_iterations; iter++) { + uint32_t row_idx = group_id + iter * num_groups; + + if (row_idx >= batch_size) break; // Early exit if out of bounds + + const uint32_t chunk_start = cta_in_group * chunk_size; + const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size); + const uint32_t actual_chunk_size = chunk_end - chunk_start; + + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; + + // ========== Stage 1: Load to shared memory ========== + vec_t logits_vec; + const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; + + // Vectorized load for aligned portion +#pragma unroll 2 + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { + logits_vec.cast_load(logits + row_idx * vocab_size + chunk_start + i); + logits_vec.store(shared_logits + i); + } + + // Scalar load for tail (only for last CTA if vocab_size not aligned) + for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { + shared_logits[i] = logits[row_idx * vocab_size + chunk_start + i]; + } + __syncthreads(); + + double pivot = -cuda::std::numeric_limits::infinity(); + + if (k < vocab_size) { + // ========== Stage 2: Initialize - find global min/max ========== + float local_min = cuda::std::numeric_limits::max(); + float local_max = cuda::std::numeric_limits::lowest(); + + // Vectorized min/max for aligned portion +#pragma unroll 2 + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { + logits_vec.load(shared_logits + i); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + float val = logits_vec[j]; + local_min = min(local_min, val); + local_max = max(local_max, val); + } + } + + // Scalar min/max for tail + for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { + float val = shared_logits[i]; + local_min = min(local_min, val); + local_max = max(local_max, val); + } + + // Block reduction + float block_min = + BlockReduce(temp_storage->block_prim.reduce) + .Reduce(local_min, MinReduceOp{}); + __syncthreads(); + + float block_max = + BlockReduce(temp_storage->block_prim.reduce) + .Reduce(local_max, MaxReduceOp{}); + __syncthreads(); + + // Atomic reduction to global state + if (tx == 0) { + atomicMinFloat(&state->min_buf[0], block_min); + atomicMaxFloat(&state->max_buf[0], block_max); + + // Signal arrival using release semantics + red_release(&state->arrival_counter, 1); + } + int target = (barrier_phase + 1) * ctas_per_group; + wait_ge(&state->arrival_counter, target, tx); + barrier_phase++; + + float global_min = state->min_buf[0]; + float global_max = state->max_buf[0]; + + // ========== Stage 3: Binary search ========== + double low = (global_min == -cuda::std::numeric_limits::infinity()) + ? cuda::std::numeric_limits::lowest() + : global_min - 1; + double high = global_max; + float min_gt_low, max_le_high; + + do { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + // Local counting from shared memory + int local_count_0 = 0, local_count_1 = 0; + float local_min_gt_low = high, local_max_le_high = low; + + // Vectorized counting for aligned portion +#pragma unroll 2 + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { + logits_vec.load(shared_logits + i); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + float val = logits_vec[j]; + // Branchless counting + local_count_0 += (val > pivot_0); + local_count_1 += (val > pivot_1); + // Update min/max + if (val > low) local_min_gt_low = min(local_min_gt_low, val); + if (val <= high) local_max_le_high = max(local_max_le_high, val); + } + } + + // Scalar counting for tail + for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { + float val = shared_logits[i]; + local_count_0 += (val > pivot_0); + local_count_1 += (val > pivot_1); + if (val > low) local_min_gt_low = min(local_min_gt_low, val); + if (val <= high) local_max_le_high = max(local_max_le_high, val); + } + + // Block reduction + int block_count_0 = + BlockReduce(temp_storage->block_prim.reduce_int) + .Sum(local_count_0); + __syncthreads(); + + int block_count_1 = + BlockReduce(temp_storage->block_prim.reduce_int) + .Sum(local_count_1); + __syncthreads(); + + float block_min_gt_low = + BlockReduce(temp_storage->block_prim.reduce) + .Reduce(local_min_gt_low, MinReduceOp{}); + __syncthreads(); + + float block_max_le_high = + BlockReduce(temp_storage->block_prim.reduce) + .Reduce(local_max_le_high, MaxReduceOp{}); + __syncthreads(); + + // Ping-pong buffer index (use persistent_iteration for double buffering) + int buffer_idx = persistent_iteration & 1; + + // Atomic reduction to global state + if (tx == 0) { + atomicAdd(&state->count_0_buf[buffer_idx], block_count_0); + atomicAdd(&state->count_1_buf[buffer_idx], block_count_1); + atomicMinFloat(&state->min_buf[buffer_idx], block_min_gt_low); + atomicMaxFloat(&state->max_buf[buffer_idx], block_max_le_high); + + // Signal arrival using release semantics + red_release(&state->arrival_counter, 1); + + // Last CTA clears next buffer (no need to reset counter anymore) + if (cta_in_group == ctas_per_group - 1) { + int next_buf = (persistent_iteration + 1) & 1; + state->count_0_buf[next_buf] = 0; + state->count_1_buf[next_buf] = 0; + state->min_buf[next_buf] = cuda::std::numeric_limits::max(); + state->max_buf[next_buf] = cuda::std::numeric_limits::lowest(); + } + } + int target = (barrier_phase + 1) * ctas_per_group; + wait_ge(&state->arrival_counter, target, tx); + barrier_phase++; + + // Read results from current buffer + int aggregate_gt_pivot_0 = state->count_0_buf[buffer_idx]; + int aggregate_gt_pivot_1 = state->count_1_buf[buffer_idx]; + min_gt_low = state->min_buf[buffer_idx]; + max_le_high = state->max_buf[buffer_idx]; + + // Update search range + if (aggregate_gt_pivot_1 >= k) { + low = pivot_1; + } else if (aggregate_gt_pivot_0 >= k) { + low = pivot_0; + high = min(pivot_1, max_le_high); + } else { + high = min(pivot_0, max_le_high); + } + + persistent_iteration++; + + } while (min_gt_low != max_le_high); + + pivot = low; + } + + // ========== Stage 4: Masking ========== + // Vectorized masking for aligned portion +#pragma unroll 2 + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { + logits_vec.load(shared_logits + i); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + logits_vec[j] = + (logits_vec[j] > pivot) ? logits_vec[j] : -cuda::std::numeric_limits::infinity(); + } + logits_vec.store(masked_logits + row_idx * vocab_size + chunk_start + i); + } + + // Scalar masking for tail + for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { + float val = shared_logits[i]; + masked_logits[row_idx * vocab_size + chunk_start + i] = + (val > pivot) ? val : -cuda::std::numeric_limits::infinity(); + } + } + + // Finalize: reset counter for this group to prepare for next kernel launch + // All iterations are done, safe to reset now + if (cta_in_group == 0 && tx == 0) { + st_release(&row_states[group_id].arrival_counter, 0); + } +} + +template +cudaError_t TopKMaskLogitsMultiCTA(DType* logits, DType* masked_logits, IdType* top_k_arr, + uint32_t batch_size, uint32_t top_k_val, uint32_t vocab_size, + RowReductionState* row_states_buffer, + cudaStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + // Calculate aligned temp storage size + constexpr size_t temp_storage_size = sizeof(RenormTempStorage); + constexpr size_t temp_storage_aligned = round_up(temp_storage_size, 16UL); + + // Get device properties + int device; + FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); + int max_smem_per_block; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_block, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + + // Calculate max chunk size that fits in shared memory + // smem layout: [temp_storage_aligned] [chunk_size * sizeof(DType)] + const size_t available_for_logits = max_smem_per_block - temp_storage_aligned; + uint32_t max_chunk_elements = available_for_logits / sizeof(DType); + + // Round down to multiple of VEC_SIZE + max_chunk_elements = round_down(max_chunk_elements, VEC_SIZE); + + // Ensure minimum chunk size for vectorized access + constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS; + max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); + + // Calculate how many CTAs needed per row + uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements); + uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group); + // Round up chunk_size to multiple of VEC_SIZE + chunk_size = round_up(chunk_size, VEC_SIZE); + // Ensure minimum chunk size + chunk_size = std::max(chunk_size, min_chunk_size); + + // Get number of SMs + int num_sms; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device)); + + // Calculate grid size (must be multiple of ctas_per_group, up to num_sms) + uint32_t num_groups = std::min(static_cast(num_sms) / ctas_per_group, batch_size); + if (num_groups == 0) { + // vocab_size too large to fit in shared memory even with one chunk per SM + return cudaErrorInvalidConfiguration; + } + uint32_t total_ctas = num_groups * ctas_per_group; + + // Calculate shared memory size + const uint32_t smem_size = temp_storage_aligned + chunk_size * sizeof(DType); + + // Launch kernel + dim3 nblks(total_ctas); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &vocab_size, + &batch_size, &row_states_buffer, &chunk_size, &ctas_per_group}; + + auto kernel = + TopKMaskLogitsKernel_MultiCTA; + + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Use regular kernel launch via cudaLaunchKernel API + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + + return cudaSuccess; + }); + }); +} + template diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 0471bd1081..716e1a805b 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -317,15 +317,20 @@ namespace flashinfer { template -__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { +__forceinline__ __device__ __host__ constexpr T1 ceil_div(const T1 x, const T2 y) noexcept { return (x + y - 1) / y; } template -__forceinline__ __device__ __host__ T1 round_up(const T1 x, const T2 y) { +__forceinline__ __device__ __host__ constexpr T1 round_up(const T1 x, const T2 y) noexcept { return ceil_div(x, y) * y; } +template +__forceinline__ __device__ __host__ constexpr T1 round_down(const T1 x, const T2 y) noexcept { + return (x / y) * y; +} + inline std::pair GetCudaComputeCapability() { int device_id = 0; cudaGetDevice(&device_id);