perf: using multi-cta optimization for top-k/top-p #2119
base: main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The PR extends the top-k logits masking implementation to support multi-CTA (Cooperative Thread Block) kernels. A new `row_states_buffer` scratch argument is threaded from the Python layer through the CUDA binding into the kernel to hold per-row cross-CTA reduction state.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Python as Python Layer
    participant Binding as CUDA Binding
    participant Kernel as TopK Kernel
    participant CTA as CTA Group
    Python->>Python: Allocate 1MB row_states_buffer (zero-init)
    Python->>Binding: top_k_mask_logits(logits, top_k_val, row_states_buffer)
    Binding->>Binding: Validate row_states_buffer
    Binding->>Kernel: Launch TopKMaskLogitsMultiCTA (multi-CTA variant)
    Note over Kernel: Dynamic SMEM sizing & grid config
    Kernel->>CTA: Distribute work across CTA groups
    par Multi-CTA Coordination
        CTA->>CTA: Per-row reduction using double buffers
        CTA->>CTA: Atomic min/max updates to RowReductionState
        CTA->>CTA: Synchronization via acquire/release operations
    end
    CTA->>CTA: Barrier sync at chunk boundaries
    Kernel->>Binding: Write updated mask_logits
    Binding->>Python: Return masked logits
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant performance enhancement for the top-k/top-p renormalization and sampling routines. By refactoring the CUDA kernels to leverage a multi-CTA architecture and prioritizing shared memory usage for logits and probabilities, the changes aim to drastically reduce memory latency and improve throughput, especially for operations involving multiple scan rounds. This optimization builds upon previous work and focuses on efficient parallel processing across the GPU.
Code Review
This pull request introduces a multi-CTA optimization for top-k/top-p sampling, which is a significant performance enhancement. The implementation is well-structured, leveraging advanced CUDA features like inter-CTA synchronization via atomic operations and memory fences to efficiently process large vocabularies. The changes in the Python bindings and C++ interface are consistent with the new kernel's requirements.

However, I've identified a critical bug in the kernel launch configuration logic within `TopKMaskLogitsMultiCTA`. The calculation for `chunk_size` can lead to requesting more shared memory than available, which would cause kernel launch failures under certain conditions. I have provided a detailed comment with a suggested fix for this issue. Overall, this is a great performance improvement, and with the suggested fix, it should be robust.
```cpp
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
// Ensure minimum chunk size
chunk_size = std::max(chunk_size, min_chunk_size);
```
The logic for calculating `chunk_size` is incorrect and can lead to requesting more shared memory than is available, causing a kernel launch failure.
Specifically:

- `max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);` at line 2468 can inflate `max_chunk_elements` beyond the available shared memory.
- `chunk_size = std::max(chunk_size, min_chunk_size);` at line 2476 can similarly cause `chunk_size` to exceed shared memory limits, as it ignores the `max_chunk_elements` constraint.

This can happen if the available shared memory is small, making `max_chunk_elements` smaller than `min_chunk_size`.
I suggest replacing this block with logic that validates against `min_chunk_size` instead of forcing it, to ensure the kernel configuration is always valid.
```cpp
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
if (max_chunk_elements < min_chunk_size) {
  // Not enough shared memory for even the minimum chunk size.
  return cudaErrorInvalidConfiguration;
}

// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
```
Actionable comments posted: 1
📒 Files selected for processing (6)
- csrc/flashinfer_sampling_binding.cu (1 hunks)
- csrc/renorm.cu (1 hunks)
- flashinfer/sampling.py (3 hunks)
- flashinfer/utils.py (1 hunks)
- include/flashinfer/sampling.cuh (1 hunks)
- include/flashinfer/utils.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)

csrc/renorm.cu (1)
- csrc/tvm_ffi_utils.h (1): `get_stream` (272-274)

flashinfer/sampling.py (4)
- flashinfer/utils.py (3): `register_custom_op` (319-328), `register_custom_op` (338-357), `_get_cache_buf` (206-217)
- csrc/flashinfer_sampling_binding.cu (1): `top_k_mask_logits` (57-59)
- csrc/renorm.cu (2): `top_k_mask_logits` (61-84), `top_k_mask_logits` (61-63)
- flashinfer/logits_processor/operators.py (1): `_to_tensor_scalar_tuple` (28-34)
🪛 Ruff (0.14.5)
flashinfer/sampling.py
407-407: Unused function argument: row_states_buffer
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```diff
-def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+def _get_cache_buf(
+    name: str, bytes: int, device: torch.device, zero_init: bool = False
+) -> torch.Tensor:
     key = (name, device)
     buf = _cache_buf.get(key)
     if buf is None or buf.size(0) < bytes:
-        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        if zero_init:
+            buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
+        else:
+            buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
```
Zero-initialized cache must be cleared on reuse
When zero_init=True we only zero the tensor on first allocation; subsequent reuses skip the zero-fill. The new multi-CTA top-k path stores RowReductionState objects in this cache and assumes every launch starts from a fully cleared buffer. A reused buffer now comes back with stale counters/min/max values, so the first CTA observes non-zero state and the reductions diverge (easy to repro by calling top_k_mask_logits twice on the same device). Please zero the buffer whenever zero_init is requested.
```diff
     if buf is None or buf.size(0) < bytes:
         if zero_init:
             buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
         else:
             buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
+    elif zero_init:
+        buf.zero_()
     return buf
```

🤖 Prompt for AI Agents
In flashinfer/utils.py around lines 205 to 216, the cache allocator only zeroes
the tensor on first allocation but does not clear reused buffers when
zero_init=True; update the function so that when an existing cached buffer is
found and zero_init is True you explicitly zero it (e.g., buf.zero_() or
buf.fill_(0)) before returning/using it, and keep the existing behavior of
allocating a zeroed tensor for new buffers; ensure the zeroing runs on the
correct device and dtype (torch.uint8).
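As an aside, the same contract expressed in raw CUDA host code is a re-clear before every launch that assumes zeroed state, not only at allocation time. The helper below is purely illustrative and not part of flashinfer's API:

```cuda
#include <cuda_runtime.h>
#include <cstddef>

// Illustrative helper (hypothetical name): clear the reused scratch buffer on
// the kernel's stream before each launch -- the CUDA analogue of buf.zero_().
void launch_with_row_states(void* row_states, size_t bytes, cudaStream_t stream) {
  cudaMemsetAsync(row_states, 0, bytes, stream);
  // ... enqueue the multi-CTA top-k masking kernel on `stream` here ...
}
```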
```cuda
__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
  // SM70 and newer use memory consistency qualifiers
  // Release pattern using acq_rel fence + relaxed modifier
```
Maybe add some more clarification: besides releasing, this also performs a reduction (sum).
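For context in this thread, a minimal sketch of a release-style reduction of this shape (an acq_rel fence followed by a relaxed atomic add) is shown below; it mirrors the helper's signature but is not the PR's exact code, and the fallback branch is an assumption:

```cuda
// Hedged sketch, not the PR's implementation: the fence publishes this CTA's
// prior writes, and the relaxed "red" both signals arrival and accumulates
// `val` into *ptr -- i.e. besides releasing, it performs a reduction (sum).
__device__ __forceinline__ void red_release_sketch(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
  asm volatile("fence.acq_rel.gpu;" ::: "memory");
  asm volatile("red.relaxed.gpu.add.s32 [%0], %1;"
               :
               : "l"(ptr), "r"(val)
               : "memory");
#else
  // Older architectures: approximate release with a device-wide fence
  // followed by an ordinary atomic add.
  __threadfence();
  atomicAdd(ptr, val);
#endif
}
```

A matching acquire on the reader side (an acquire load or fence before consuming the reduced value) completes the synchronization.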
```cuda
  int persistent_iteration = 0;

  // Calculate total number of iterations for persistent loop
  uint32_t num_groups = gridDim.x / ctas_per_group;
```
Should we add tests that explicitly trigger `num_groups > 1`?
We do have such coverage: for vocab_size=128256, each row is split into 4 chunks (one per SM), so when batch_size is greater than 33 the grid needs more than 132 CTAs (the number of SMs on an H100) and num_groups exceeds 1.
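To make the arithmetic concrete, here is a small host-side sketch; the shared-memory budget (`max_chunk_elements`) and the vector width are assumed values chosen only so that vocab_size=128256 yields the 4 chunks mentioned above, and `ceil_div`/`round_up` are re-declared locally:

```cuda
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static uint32_t round_up(uint32_t a, uint32_t b) { return ceil_div(a, b) * b; }

int main() {
  const uint32_t vocab_size = 128256;         // vocabulary size from the example above
  const uint32_t max_chunk_elements = 36864;  // assumed shared-memory budget (in elements)
  const uint32_t vec_size = 8;                // assumed vector width
  const uint32_t num_sms = 132;               // H100

  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);             // -> 4
  uint32_t chunk_size = round_up(ceil_div(vocab_size, ctas_per_group), vec_size); // -> 32064
  // With 4 CTAs per row, one wave of 132 SMs covers 132 / 4 = 33 rows, so a
  // batch size above 33 forces groups to iterate over additional rows.
  uint32_t rows_per_wave = num_sms / ctas_per_group;                              // -> 33
  printf("ctas_per_group=%u chunk_size=%u rows_per_wave=%u\n",
         ctas_per_group, chunk_size, rows_per_wave);
  return 0;
}
```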
📌 Description
Follow-up of #2044, this PR optimizes the top-k/top-p renorm/sampling using multi-CTA optimizations; more specifically:
The major advantage over the main branch is that logits/probabilities are kept in shared memory, so multiple rounds of scanning do not hurt performance much.
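To illustrate the idea (not the PR's actual kernel), the sketch below stages one CTA's chunk of logits in shared memory once and then runs a scan round entirely out of shared memory; a single pivot is used for brevity, whereas the PR refines multiple pivots via binary search, and all names and the launch configuration are illustrative assumptions:

```cuda
#include <cstdint>

// Count how many logits in this CTA's chunk exceed `pivot`, re-reading only
// shared memory after a one-time load from global memory.
__global__ void count_above_pivot_sketch(const float* logits, uint32_t chunk_size,
                                         float pivot, unsigned int* count_out) {
  extern __shared__ float smem_logits[];
  const float* chunk = logits + static_cast<size_t>(blockIdx.x) * chunk_size;

  // One-time global -> shared staging of this CTA's chunk.
  for (uint32_t i = threadIdx.x; i < chunk_size; i += blockDim.x) {
    smem_logits[i] = chunk[i];
  }
  __syncthreads();

  // One scan round over shared memory; further rounds with refined pivots
  // would not touch global memory again.
  uint32_t local_count = 0;
  for (uint32_t i = threadIdx.x; i < chunk_size; i += blockDim.x) {
    local_count += (smem_logits[i] > pivot) ? 1u : 0u;
  }
  // Kept simple here; a real kernel would reduce within the block first.
  atomicAdd(count_out, local_count);
}

// Example launch (dynamic shared memory sized to hold the chunk):
//   count_above_pivot_sketch<<<ctas_per_group, 256, chunk_size * sizeof(float)>>>(
//       logits, chunk_size, pivot, count_out);
```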
Note that we also tried radix top-k (in https://github.com/yzh119/flashinfer-dev/tree/radix-top-k), which did not show a performance benefit over our multi-pivot binary search; we might revisit it in future PRs.
Speedup
On H100, the speedup over 0.5.2 is available at:
https://docs.google.com/spreadsheets/d/1DO8_11gzv-EUACCY6q4IMIHa8SaYv4q8hJ6gZl-D0mU/edit?usp=sharing
On consumer GPUs (e.g. Ada6000), the gap is even larger; for small batch sizes with a large vocabulary, the gap can be as large as 10 times.
🔍 Related Issues
#2044
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- All tests are passing (`unittest`, etc.).

Reviewer Notes