
Conversation

@yzh119 (Collaborator) commented Nov 20, 2025

📌 Description

Follow-up to #2044: this PR optimizes top-k/top-p renorm/sampling with multi-CTA kernels. More specifically, we:

  • split the vocabulary into chunks and let each CTA handle one chunk
  • make sure the logits/probs of a chunk fit into shared memory inside a CTA
  • use global memory to store the data structures needed for cross-CTA synchronization
  • make sure the total number of CTAs (a multiple of num_chunks) does not exceed the number of SMs, looping over rows when the batch size is greater than the number of groups (num_ctas/num_chunks)

The major advantage over the main branch is that logits/probabilities are kept in shared memory, so running multiple rounds of scans over them has little impact on performance.
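
As a rough, hedged sketch of the chunking and grid-sizing described above (the helper name plan_multi_cta_grid and parameters such as num_sms and max_smem_elements are placeholders for illustration, not the PR's actual code):

```cuda
#include <algorithm>
#include <cstdint>

// Simplified sketch of the multi-CTA launch planning: one chunk per CTA,
// ctas_per_group CTAs cooperate on one row, and groups loop over rows when
// the batch is larger than the number of groups that fit on the GPU.
inline void plan_multi_cta_grid(uint32_t vocab_size, uint32_t batch_size,
                                uint32_t num_sms, uint32_t max_smem_elements,
                                uint32_t& ctas_per_group, uint32_t& num_groups,
                                uint32_t& rows_per_group) {
  // Each CTA owns one vocabulary chunk that must fit in its shared memory.
  ctas_per_group = (vocab_size + max_smem_elements - 1) / max_smem_elements;
  // Keep the total CTA count (num_groups * ctas_per_group) within the SM count.
  num_groups = std::max<uint32_t>(1, num_sms / ctas_per_group);
  num_groups = std::min(num_groups, batch_size);
  // If batch_size > num_groups, each group iterates over several rows
  // (persistent-CTA style) instead of launching more blocks.
  rows_per_group = (batch_size + num_groups - 1) / num_groups;
}
```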

Note that we also tried radix top-k (in https://github.com/yzh119/flashinfer-dev/tree/radix-top-k), which did not show a performance benefit over our multi-pivot binary search; we might revisit it in future PRs.

Speedup

On H100, the speedup over v0.5.2 is available at:
https://docs.google.com/spreadsheets/d/1DO8_11gzv-EUACCY6q4IMIHa8SaYv4q8hJ6gZl-D0mU/edit?usp=sharing

On consumer GPUs (e.g., Ada6000), the gap is even larger, e.g., for a small batch size with a large vocabulary:

v0.5.2
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 10, duration: 3004.42 us, effective bandwidth: 2.73 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 100, duration: 3633.15 us, effective bandwidth: 2.25 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 1000, duration: 4258.82 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 5000, duration: 4256.77 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 10, duration: 2376.80 us, effective bandwidth: 3.45 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 100, duration: 3627.01 us, effective bandwidth: 2.26 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 1000, duration: 3945.38 us, effective bandwidth: 2.08 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 5000, duration: 4259.84 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 10, duration: 3316.74 us, effective bandwidth: 2.47 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 100, duration: 3624.96 us, effective bandwidth: 2.26 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 1000, duration: 4566.02 us, effective bandwidth: 1.79 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 5000, duration: 4576.26 us, effective bandwidth: 1.79 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 10, duration: 3003.39 us, effective bandwidth: 2.73 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 100, duration: 4260.86 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 1000, duration: 3947.52 us, effective bandwidth: 2.08 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 5000, duration: 5514.24 us, effective bandwidth: 1.49 GB/s

this PR
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 10, duration: 322.56 us, effective bandwidth: 25.40 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 1000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 5000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 10, duration: 257.02 us, effective bandwidth: 31.87 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 1000, duration: 421.89 us, effective bandwidth: 19.42 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 5000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 10, duration: 355.33 us, effective bandwidth: 23.05 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 1000, duration: 486.40 us, effective bandwidth: 16.84 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 5000, duration: 488.45 us, effective bandwidth: 16.77 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 10, duration: 320.51 us, effective bandwidth: 25.56 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 100, duration: 452.61 us, effective bandwidth: 18.10 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 1000, duration: 422.91 us, effective bandwidth: 19.37 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 5000, duration: 585.73 us, effective bandwidth: 13.99 GB/s

The speedup can be as large as roughly 10x.

🔍 Related Issues

#2044

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Enhanced top-k sampling with multi-CTA kernel support for improved throughput and efficiency on larger batches.
    • Improved GPU memory management with optimized buffer allocation strategy for sampling operations.
  • Refactor

    • Internal improvements to sampling infrastructure with optimized device synchronization and utility functions.

✏️ Tip: You can customize this high-level summary in your review settings.

coderabbitai bot (Contributor) commented Nov 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The PR extends the top-k logits masking implementation to support multi-CTA (Cooperative Thread Array) kernels. A new row_states_buffer parameter is threaded through the call chain: the Python bindings allocate and pass the buffer, the CUDA bindings accept it, and the C++ core uses it with a new multi-CTA kernel variant. Device-level synchronization primitives are introduced, and utility functions gain compile-time qualifiers.
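
The RowReductionState<T> type and row_states_buffer mentioned here live in include/flashinfer/sampling.cuh; the field names and layout in the sketch below are assumptions for illustration only, showing how a zero-initialized byte buffer could be viewed as one state object per row:

```cuda
#include <cstdint>

// Hedged sketch only: the actual RowReductionState<T> in the PR may have
// different fields, ordering, and double-buffering details.
template <typename T>
struct RowReductionStateSketch {
  T pivot_min;       // running minimum shared by all CTAs working on this row
  T pivot_max;       // running maximum shared by all CTAs working on this row
  int count[2];      // double-buffered cross-CTA counters (e.g., elements >= pivot)
  int arrival_flag;  // release/acquire flag CTAs bump to signal round completion
};

// View the zero-initialized row_states_buffer handed in from Python as an
// array of per-row states.
template <typename T>
__device__ __forceinline__ RowReductionStateSketch<T>* row_state(
    void* row_states_buffer, uint32_t row_idx) {
  return reinterpret_cast<RowReductionStateSketch<T>*>(row_states_buffer) + row_idx;
}
```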

Changes

  • Core Binding Updates (csrc/flashinfer_sampling_binding.cu, csrc/renorm.cu): Extended the top_k_mask_logits function signatures to accept a row_states_buffer parameter; updated the kernel invocation to use the multi-CTA variant with a RowReductionState pointer.
  • CUDA Kernel & Device Helpers (include/flashinfer/sampling.cuh): Introduced the multi-CTA top-k masking path: new atomic helpers (atomicMinFloat, atomicMaxFloat), synchronization primitives (ld_acquire, red_release, st_release, wait_ge), a RowReductionState<T> struct, and dual kernels (TopKMaskLogitsKernel_MultiCTA, TopKMaskLogitsMultiCTA) with dynamic shared memory and grid configuration logic. A hedged sketch of such helpers follows this list.
  • Python Wrapper (flashinfer/sampling.py): Extended the top_k_mask_logits signature to include a row_states_buffer parameter; allocated a 1MB per-call buffer via _get_cache_buf; updated the fake implementation to accept and propagate the buffer; adjusted the return type to float32.
  • Utility Functions (include/flashinfer/utils.cuh): Added constexpr and noexcept qualifiers to ceil_div and round_up; introduced a new round_down utility function.
  • Cache Utilities (flashinfer/utils.py): Extended the _get_cache_buf helper with a zero_init parameter to support zero-initialized buffer allocation.
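
As flagged in the kernel/device-helper entry above, the multi-CTA path relies on float atomics and acquire/release-style synchronization. Below is a hedged sketch of such helpers using portable approximations; the PR's actual implementations (e.g., PTX-level acquire loads and red.release on SM70+) may differ:

```cuda
#include <cuda_runtime.h>

// Atomic minimum on float via a compare-and-swap loop (illustrative sketch).
__device__ __forceinline__ float atomicMinFloatSketch(float* addr, float value) {
  int* addr_as_int = reinterpret_cast<int*>(addr);
  int old = *addr_as_int;
  while (value < __int_as_float(old)) {
    int assumed = old;
    old = atomicCAS(addr_as_int, assumed, __float_as_int(value));
    if (old == assumed) break;  // our value was installed
  }
  return __int_as_float(old);
}

// Acquire-load approximation: volatile read followed by a device-scope fence.
__device__ __forceinline__ int ld_acquire_sketch(const int* ptr) {
  int v = *reinterpret_cast<const volatile int*>(ptr);
  __threadfence();
  return v;
}

// Spin until the shared counter reaches `target`, e.g., to wait for every CTA
// of a group to publish its partial min/max/count for the current round.
__device__ __forceinline__ void wait_ge_sketch(const int* ptr, int target) {
  while (ld_acquire_sketch(ptr) < target) {
    // busy-wait; the real kernel may back off between polls
  }
}
```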

Sequence Diagram(s)

sequenceDiagram
    participant Python as Python Layer
    participant Binding as CUDA Binding
    participant Kernel as TopK Kernel
    participant CTA as CTA Group

    Python->>Python: Allocate 1MB row_states_buffer (zero-init)
    Python->>Binding: top_k_mask_logits(logits, top_k_val, row_states_buffer)
    
    Binding->>Binding: Validate row_states_buffer
    Binding->>Kernel: Launch TopKMaskLogitsMultiCTA (multi-CTA variant)
    
    Note over Kernel: Dynamic SMEM sizing & grid config
    Kernel->>CTA: Distribute work across CTA groups
    
    par Multi-CTA Coordination
        CTA->>CTA: Per-row reduction using double buffers
        CTA->>CTA: Atomic min/max updates to RowReductionState
        CTA->>CTA: Synchronization via acquire/release operations
    end
    
    CTA->>CTA: Barrier sync at chunk boundaries
    Kernel->>Binding: Write updated mask_logits
    Binding->>Python: Return masked logits

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Multi-CTA synchronization logic: The new atomic helpers, release/acquire primitives, and barrier synchronization in TopKMaskLogitsKernel_MultiCTA require careful verification of correctness and thread-safety.
  • Device code complexity: The include/flashinfer/sampling.cuh kernel addition involves dynamic shared memory management, grid/block sizing heuristics, and cross-CTA coordination that demand close scrutiny.
  • Cross-layer integration: Changes span Python bindings, CUDA C++ bindings, and device kernels, requiring tracing the full call path and buffer lifecycle.
  • RowReductionState struct usage: Verify buffer layout, offset calculations, and double-buffering scheme for correctness.

Possibly related PRs

Suggested reviewers

  • bkryu
  • cyx-6
  • djmmoss
  • aleozlx
  • wenscarl
  • nvmbreughe

Poem

🐰 Hoppy tails and CUDA streams,
Multi-CTA threads team up with gleams,
Row states synchronize with care,
Atomics dance through shared memory's air,
Top-K masking now scales beyond one block's dream!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 14.29%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title accurately summarizes the main change: multi-CTA optimization for top-k/top-p operations, which is the primary focus of this PR.
  • Description check (✅ Passed): The PR description comprehensively covers the objectives, implementation details, performance improvements, and related issue references, with all required template sections completed.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement for the top-k/top-p renormalization and sampling routines. By refactoring the CUDA kernels to leverage a multi-CTA architecture and prioritizing shared memory usage for logits and probabilities, the changes aim to drastically reduce memory latency and improve throughput, especially for operations involving multiple scan rounds. This optimization builds upon previous work and focuses on efficient parallel processing across the GPU.

Highlights

  • Multi-CTA Optimization: The top-k/top-p renorm/sampling process has been optimized using a multi-CTA approach, where the vocabulary is split into chunks, and each CTA processes a chunk.
  • Shared Memory Utilization: Logits and probabilities are now ensured to fit into shared memory within each CTA, significantly improving performance by reducing global memory access during multi-round scans.
  • Cross-CTA Synchronization: Global memory is utilized to store data structures necessary for efficient synchronization between CTAs, enabling coordinated processing across the GPU.
  • Dynamic Grid Sizing: The total number of CTAs is dynamically managed to not exceed the number of Streaming Multiprocessors (SMs), with a looping mechanism to iterate over rows when the batch size is larger than the number of available groups.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a multi-CTA optimization for top-k/top-p sampling, which is a significant performance enhancement. The implementation is well-structured, leveraging advanced CUDA features like inter-CTA synchronization via atomic operations and memory fences to efficiently process large vocabularies. The changes in the Python bindings and C++ interface are consistent with the new kernel's requirements. However, I've identified a critical bug in the kernel launch configuration logic within TopKMaskLogitsMultiCTA. The calculation for chunk_size can lead to requesting more shared memory than available, which would cause kernel launch failures under certain conditions. I have provided a detailed comment with a suggested fix for this issue. Overall, this is a great performance improvement, and with the suggested fix, it should be robust.

Comment on lines +2467 to +2476
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
// Ensure minimum chunk size
chunk_size = std::max(chunk_size, min_chunk_size);

critical

The logic for calculating chunk_size is incorrect and can lead to requesting more shared memory than is available, causing a kernel launch failure.

Specifically:

  • max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); at line 2468 can inflate max_chunk_elements beyond the available shared memory.
  • chunk_size = std::max(chunk_size, min_chunk_size); at line 2476 can similarly cause chunk_size to exceed shared memory limits, as it ignores the max_chunk_elements constraint.

This can happen if the available shared memory is small, making max_chunk_elements smaller than min_chunk_size.

I suggest replacing this block with logic that validates against min_chunk_size instead of forcing it, to ensure the kernel configuration is always valid.

      constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
      if (max_chunk_elements < min_chunk_size) {
        // Not enough shared memory for even the minimum chunk size.
        return cudaErrorInvalidConfiguration;
      }

      // Calculate how many CTAs needed per row
      uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
      uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
      // Round up chunk_size to multiple of VEC_SIZE
      chunk_size = round_up(chunk_size, VEC_SIZE);
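
For reference, the ceil_div / round_up helpers used in the snippet above (and, per the CodeRabbit walkthrough, extended with constexpr/noexcept plus a new round_down in include/flashinfer/utils.cuh) are simple integer utilities. A hedged sketch of their usual form, with a worked example on hypothetical sizes:

```cuda
#include <cstdint>

// Sketches only; the real utilities in include/flashinfer/utils.cuh may be
// templated and carry host/device qualifiers.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) noexcept { return (a + b - 1) / b; }
constexpr uint32_t round_up(uint32_t a, uint32_t b) noexcept { return ceil_div(a, b) * b; }
constexpr uint32_t round_down(uint32_t a, uint32_t b) noexcept { return (a / b) * b; }

// Example with hypothetical numbers (max_chunk_elements = 65536, VEC_SIZE = 8):
//   ctas_per_group = ceil_div(256000, 65536) = 4
//   chunk_size     = round_up(ceil_div(256000, 4), 8) = 64000
```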

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 049e8db and 51af95c.

📒 Files selected for processing (6)
  • csrc/flashinfer_sampling_binding.cu (1 hunks)
  • csrc/renorm.cu (1 hunks)
  • flashinfer/sampling.py (3 hunks)
  • flashinfer/utils.py (1 hunks)
  • include/flashinfer/sampling.cuh (1 hunks)
  • include/flashinfer/utils.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/sampling.py (4)
flashinfer/utils.py (3)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • _get_cache_buf (206-217)
csrc/flashinfer_sampling_binding.cu (1)
  • top_k_mask_logits (57-59)
csrc/renorm.cu (2)
  • top_k_mask_logits (61-84)
  • top_k_mask_logits (61-63)
flashinfer/logits_processor/operators.py (1)
  • _to_tensor_scalar_tuple (28-34)
🪛 Ruff (0.14.5)
flashinfer/sampling.py

407-407: Unused function argument: row_states_buffer

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines 205 to 216

-def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+def _get_cache_buf(
+    name: str, bytes: int, device: torch.device, zero_init: bool = False
+) -> torch.Tensor:
     key = (name, device)
     buf = _cache_buf.get(key)
     if buf is None or buf.size(0) < bytes:
-        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        if zero_init:
+            buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
+        else:
+            buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf

⚠️ Potential issue | 🔴 Critical

Zero-initialized cache must be cleared on reuse

When zero_init=True we only zero the tensor on first allocation; subsequent reuses skip the zero-fill. The new multi-CTA top-k path stores RowReductionState objects in this cache and assumes every launch starts from a fully cleared buffer. A reused buffer now comes back with stale counters/min/max values, so the first CTA observes non-zero state and the reductions diverge (easy to repro by calling top_k_mask_logits twice on the same device). Please zero the buffer whenever zero_init is requested.

     if buf is None or buf.size(0) < bytes:
         if zero_init:
             buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
         else:
             buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
+    elif zero_init:
+        buf.zero_()
     return buf
🤖 Prompt for AI Agents
In flashinfer/utils.py around lines 205 to 216, the cache allocator only zeroes
the tensor on first allocation but does not clear reused buffers when
zero_init=True; update the function so that when an existing cached buffer is
found and zero_init is True you explicitly zero it (e.g., buf.zero_() or
buf.fill_(0)) before returning/using it, and keep the existing behavior of
allocating a zeroed tensor for new buffers; ensure the zeroing runs on the
correct device and dtype (torch.uint8).

__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
// SM70 and newer use memory consistency qualifiers
// Release pattern using acq_rel fence + relaxed modifier
Contributor:

Maybe add some more clarification: besides releasing, this also performs a reduction (sum).
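
For context, a hedged sketch of what such a combined release-plus-reduction helper typically looks like, written with a portable fence plus atomicAdd rather than the SM70 PTX path referenced in the snippet (illustration only, not the PR's code):

```cuda
// Release-style reduction sketch: the fence publishes earlier writes, and the
// relaxed atomic add then performs the sum part of the "red_release".
__device__ __forceinline__ void red_release_sketch(int* ptr, int val) {
  __threadfence();      // make prior shared/global writes visible device-wide
  atomicAdd(ptr, val);  // fetch-and-add: the reduction (sum) component
}
```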

int persistent_iteration = 0;

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
Contributor:

Should we add tests that explicitly trigger num_groups > 1?

Collaborator Author:

We do: for vocab_size=128256 the vocabulary is split into 4 chunks (one per SM), so when batch_size is greater than 33 the total number of CTAs required (4 per row) exceeds 132, the number of SMs on an H100 (Hopper).
