Refactor trtllm_mnnvl_allreduce #2118
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Replaces the legacy MNNVL all-reduce with a fused lamport-buffer allreduce exposing
Changes
Sequence Diagram(s)
sequenceDiagram
participant App as Application
participant PyAPI as Python API / Workspace
participant Strategy as Strategy Selector
participant Comm as Comm Backend / IpcSocket / MPI
participant Kernel as CUDA Kernel (fusion)
participant Out as Output Buffers
App->>PyAPI: call trtllm_mnnvl_allreduce(...) / fused variant
PyAPI->>PyAPI: validate inputs, prepare workspace & outputs
PyAPI->>Strategy: select ONESHOT / TWOSHOT (AUTO may inspect sizes)
PyAPI->>Comm: exchange/share handles (MPI bcast/barrier or IpcSocket FD exchange)
PyAPI->>Kernel: invoke trtllm_mnnvl_allreduce_fusion with params (rmsnorm_fusion, use_oneshot,...)
rect rgb(235,245,255)
Kernel->>Kernel: lamport-stage broadcast & rank reduction
alt RMSNorm fusion enabled
Kernel->>Kernel: compute RMS, apply gamma, add residuals
end
Kernel->>Out: write output (and residual_out if present)
end
Out-->>PyAPI: return tensor(s)
PyAPI-->>App: deliver result(s)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
Hello @timlee0212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a comprehensive refactoring of the multi-node NVLink (MNNVL) all-reduce system within FlashInfer. It unifies the all-reduce and RMSNorm operations into a single, highly configurable C++ kernel, exposed through intuitive new Python APIs. A key improvement is the new workspace management class, which automates and optimizes buffer allocation. Furthermore, the PR adds crucial support for IPC Socket-based handle transfer, broadening compatibility to hardware environments like DGX machines. These changes collectively enhance the flexibility, performance, and overall robustness of distributed computations.
Highlights
Code Review
This pull request is a significant refactoring of the MNNVL all-reduce implementation, introducing a new, cleaner API with a dedicated workspace manager class, and adding support for IPC sockets for single-node communication. The changes are extensive and substantially improve the code's structure and capabilities. My review focuses on ensuring backward compatibility is fully maintained as intended, removing leftover debug code, improving memory usage efficiency, adding a critical safety check for buffer sizes in the new API, and suggesting a minor precision improvement in a CUDA kernel.
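To make the refactored surface concrete, here is a minimal usage sketch assembled from the signatures quoted in this PR. The Mapping constructor arguments, import paths, GPU counts, and tensor shapes are assumptions for illustration, not verified against the final code.

```python
# Hedged sketch of the new workspace-based API; Mapping arguments and import
# paths are assumptions based on the files touched in this PR.
import torch
from mpi4py import MPI
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.trtllm_mnnvl_ar import (
    MNNVLAllreduceFusionStrategy,
    MNNVLAllreduceFusionWorkspace,
    trtllm_mnnvl_allreduce,
)

rank = MPI.COMM_WORLD.Get_rank()
world_size = MPI.COMM_WORLD.Get_size()
torch.cuda.set_device(rank % 8)  # assumption: 8 GPUs per node
mapping = Mapping(world_size=world_size, rank=rank, gpus_per_node=8, tp_size=world_size)

# One workspace per process; COMM_WORLD is used for handle exchange and
# synchronization during construction.
workspace = MNNVLAllreduceFusionWorkspace(mapping)

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
# AUTO lets the heuristic choose between the one-shot and two-shot kernels.
out = trtllm_mnnvl_allreduce(
    x,
    workspace,
    launch_with_pdl=True,
    strategy=MNNVLAllreduceFusionStrategy.AUTO,
)
```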
def trtllm_mnnvl_allreduce(
    input: torch.Tensor,
    workspace: MNNVLAllreduceFusionWorkspace,
    launch_with_pdl: bool,
    output: Optional[torch.Tensor] = None,
    strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
The new trtllm_mnnvl_allreduce function is missing a check to ensure that the input tensor fits within the allocated workspace. The old API had a check like if inp.shape[0] > buffer_M: raise ValueError(...). A similar check should be added here to prevent potential out-of-bounds memory access, which could lead to crashes or incorrect results. The required buffer size depends on the strategy (one-shot vs. two-shot) and can be calculated using MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes.
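For concreteness, a guard along these lines could be added near the top of the function. This is a sketch that assumes the helper and attribute names quoted elsewhere in this review (get_required_buffer_size_bytes, workspace.buffer_size_bytes) keep the signatures shown there.

```python
# Sketch of the suggested capacity check; names and signature are assumptions
# taken from snippets quoted in this review, not verified implementation.
import torch
from flashinfer.comm.trtllm_mnnvl_ar import (
    MNNVLAllreduceFusionStrategy,
    MNNVLAllreduceFusionWorkspace,
)

def check_workspace_capacity(
    input: torch.Tensor,
    workspace: MNNVLAllreduceFusionWorkspace,
    strategy: MNNVLAllreduceFusionStrategy,
) -> None:
    # Compute the bytes this call would need and compare against the allocation.
    required = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
        workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
    )
    if required > workspace.buffer_size_bytes:
        raise ValueError(
            f"Input requires {required} bytes of MNNVL workspace, but only "
            f"{workspace.buffer_size_bytes} bytes were allocated."
        )
```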
Do we want this check on the execution path? Or should we assume it is the user's responsibility to ensure the buffer does not overflow?
We do want this check. I recently added it because it did bite others.
Added.
self.mcast_device_memory.lamport_initialize(rank, dtype)

-    def get_mc_buffer(
+    def get_multicast_buffer(
This class is used internally and was left as a placeholder; it is not implemented. Thus, a breaking change is fine. Tagging @nvmbreughe for confirmation.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/comm/mnnvl.py (1)
132-149: alloc_and_copy_to_cuda return type and empty-input behavior are inconsistent. The function is annotated as returning int but returns None when host_ptr_array is empty. Callers currently pass non-empty lists, but this mismatch can trip type checkers and hide bugs if an empty list is ever passed. Either tighten behavior or relax the signature, for example:
- If empty input is invalid, raise:
    def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
        if not host_ptr_array:
            raise ValueError("host_ptr_array must be non-empty")
- Or, if you want the sentinel, change the annotation to int | None and document the None case.
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
328-427: Move allgather() and the final mpi_barrier() into the finally block to ensure all ranks participate in collectives. Lines 414 and 434 create a deadlock risk in error scenarios. The allgather() at line 414 is inside the except block, so only ranks that hit an exception call it. Meanwhile, the mpi_barrier() at line 434 is unconditionally called after try/except/finally. If an error occurs on some but not all ranks, failing ranks block in allgather() waiting for non-failing ranks that never enter the except block, while non-failing ranks block in the final barrier, and neither side can proceed.
Move the allgather() call and the final mpi_barrier() into the finally block so that all ranks participate in these collective operations:
    rank_failed = False
    try:
        ...
    except Exception as e:
        rank_failed = True
        failure_message = ...
        print(failure_message)
        import traceback
        print(traceback.format_exc())
        raise
    finally:
        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
        if any(all_failures):
            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
            if rank == 0:
                print(f"Test failed on ranks: {failed_ranks}")
        if "workspace" in locals():
            del workspace
        trtllm_mnnvl_ar.mpi_barrier()
This applies to lines 328-426 (the main try/except) and line 434 (the final barrier).
🧹 Nitpick comments (8)
flashinfer/comm/mnnvl.py (1)
640-655: Minor polish: unused recvmsg outputs and predictable opId. Two small, non-blocking cleanups:
- In IpcSocket.recv_fd(), the unpacked msg, flags, and addr from recvmsg are unused. Renaming them to _msg, _flags, _addr will make that explicit and silence linters: _msg, ancdata, _flags, _addr = self.sock.recvmsg(...)
- opId for the socket name is generated with random.randint. Since it's only used as a uniqueness hint (not security-sensitive), this is fine; if you want to appease S311 you could switch to secrets.randbits(64) or document that it's non-cryptographic.
Both are optional, but would make static analysis quieter.
Also applies to: 885-889
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)
23-25: Explicitly include <array> and <tuple>, and guard adjustGridConfig against smCount == 0. Within this header:
- LamportBufferLayout, LamportFlags, PackedVec, and several kernels use std::array. adjustGridConfig returns std::tuple<int, int, int> and callers use std::get. But only <type_traits> is included; <array> and <tuple> are currently pulled in (if at all) via transitive includes, which is fragile.
- adjustGridConfig relies on GetCudaMultiProcessorCount():
    int smCount = GetCudaMultiProcessorCount();
    while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { ... }
  If GetCudaMultiProcessorCount() ever returns 0 (e.g., CUDA error or misconfiguration), this loop will keep shrinking clusterSize and inflating blockSize in a somewhat opaque way.
Suggestions:
- Add explicit includes at the top of the header:
    #include <array>
    #include <tuple>
- Make adjustGridConfig robust to a 0 or negative SM count by early-clamping:
    int smCount = GetCudaMultiProcessorCount();
    if (smCount <= 0) {
      // Fall back to single-SM configuration
      clusterSize = 1;
      blockSize = std::min(threadsNeeded, 1024);
      return {blockSize, clusterSize, loadsPerThread};
    }
  This keeps the fused path predictable even if the helper cannot obtain a valid SM count.
Also applies to: 54-177, 143-163, 291-313, 348-359, 385-419, 449-497
509-651: Confirm lamport clear / wait protocol assumptions for the oneshot kernel. The oneshot fused kernel uses LamportFlags as follows:
- Out-of-bounds threads call ctaArrive(), then clearDirtyLamportBuf(), and return.
- In-bounds threads:
  - write their shard into the multicast lamport buffer,
  - call ctaArrive() again,
  - then call clearDirtyLamportBuf() and spin on the Lamport buffers until all entries are non-negZero.
This protocol assumes:
- Every thread in the grid calls clearDirtyLamportBuf() exactly once per iteration.
- Buffer flags and bytesToClear are correctly initialized to match the configured numTokens * tokenDim * WorldSize.
Given that this is a direct Lamport port, the logic looks consistent, but the protocol is subtle. I'd recommend:
- Double-checking that the initialization of buffer_flags in MNNVLAllreduceFusionWorkspace matches the expectations here (current index, dirty index, bytes per buffer, and stage counts).
- Adding a brief comment near the kernel launch documenting that buffer_flags must follow the [cur, dirty, bytes_per_buffer, dirty_num_stages, bytes_to_clear[4], access_ptr] layout used by LamportFlags.
No code change is strictly required, but the invariants are nontrivial and worth locking down in comments/tests.
754-885: Two-shot path & RMSNorm fusion: validate world sizes and loads-per-thread bounds. The two-shot kernels and dispatchers introduce several constraints:
- twoshotAllreduceFusionDispatch<T> only supports nRanks in {2, 4, 8, 16, 32, 64} and enforces tokenDim % (sizeof(float4) / sizeof(T)) == 0.
- rmsNormLamport is instantiated with LoadsPerThread in [1, 8] and uses float4 loads into shared memory; dynamic shared memory is sized as 3 * rnBlockSize * iters * sizeof(T) and indexed accordingly.
The implementation looks coherent, but a few invariants are implicit:
- MNNVLTwoShotStage::NUM_STAGES must stay in sync with the LamportFlags<float4> usage and the two bytes_to_clear entries in waitAndUpdate.
- rnLoadsPerThread retrieved from adjustGridConfig must remain in [1, 8]; the default: branch already errors if it's out of range, which is good.
- rnClusterSize from adjustGridConfig is assumed to be <= 8, given __shared__ float sharedVal[8]; in the RMSNorm kernel.
Given these contracts, I'd suggest:
- Adding asserts (or comments) that rnClusterSize <= 8 when CGA is used, to guard future changes to adjustGridConfig.
- Extending tests to cover corner cases where tokenDim is just at or above the supported boundary (e.g., maximum hidden size and multiple world sizes) so we don't regress the FLASHINFER_CHECK conditions.
Functionally the code looks sound; this is mainly about making the implicit constraints explicit.
Also applies to: 898-959, 1062-1219
csrc/trtllm_mnnvl_allreduce.cu (1)
99-107: Error message still mentions "twoshot" even for the oneshot path. Regardless of use_oneshot, the failure message says:
    TVM_FFI_ICHECK(status == cudaSuccess)
        << "twoshot_allreduce_dispatch_world_size failed with error code " << cudaGetErrorString(status);
This is slightly misleading when the oneshot dispatch is used. Consider making the message neutral (e.g., "allreduce_fusion_dispatch failed…") or switching on use_oneshot to provide a more accurate label. Behavior is otherwise fine.
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
232-270: Use the same eps for the reference RMSNorm as the fused kernel. In prepare_test_data, the fused reference path uses:
    norm_out = rmsnorm(
        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
    )
But the actual fused kernel is driven by the eps argument passed into row_linear_residual_norm_fusion_forward (eps = 1e-5 in run_mnnvl_ar_full). To keep the reference as close as possible to the fused implementation (and not rely on loose tolerances), consider:
    def prepare_test_data(..., fusion: bool, eps: float):
        ...
        if fusion:
            ...
            norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)
and threading eps through the call sites.
273-281: Annotate legacy_explicit_workspace_bytes as optional. Ruff's RUF013 warning here is valid:
    def run_mnnvl_ar_full(..., legacy_explicit_workspace_bytes: int = None, legacy_api: bool = False):
Changing the signature to make the optionality explicit improves readability and typing:
    from typing import Optional

    def run_mnnvl_ar_full(
        ...,
        legacy_explicit_workspace_bytes: Optional[int] = None,
        legacy_api: bool = False,
    ) -> None: ...
or, in Python 3.10+:
    legacy_explicit_workspace_bytes: int | None = None
flashinfer/comm/trtllm_mnnvl_ar.py (1)
203-205: Drop debug print from hot path.
This unconditional print runs on every call and should be dropped:
    - print(
    -     f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}"
    - )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/mnnvl.py (18 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
- include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
- include/flashinfer/utils.cuh (1 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(301-391)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(229-298)get_allreduce_mnnvl_workspace(398-451)get_required_buffer_size_bytes(116-141)flashinfer/comm/mnnvl.py (10)
barrier(168-168)barrier(227-228)bcast(165-165)bcast(224-225)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)cudaGetErrorString(146-147)csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-222)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1121-1201)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1101-1118)lamport_initialize(1160-1161)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(31-109)trtllm_mnnvl_allreduce_fusion(31-37)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)
🪛 Ruff (0.14.5)
tests/comm/test_trtllm_mnnvl_allreduce.py
279-279: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
flashinfer/comm/trtllm_mnnvl_ar.py
74-76: Avoid specifying long messages outside the exception class
(TRY003)
261-263: Avoid specifying long messages outside the exception class
(TRY003)
268-270: Avoid specifying long messages outside the exception class
(TRY003)
338-340: Avoid specifying long messages outside the exception class
(TRY003)
342-344: Avoid specifying long messages outside the exception class
(TRY003)
346-348: Avoid specifying long messages outside the exception class
(TRY003)
352-354: Avoid specifying long messages outside the exception class
(TRY003)
358-360: Avoid specifying long messages outside the exception class
(TRY003)
500-502: Avoid specifying long messages outside the exception class
(TRY003)
571-573: Avoid specifying long messages outside the exception class
(TRY003)
577-579: Avoid specifying long messages outside the exception class
(TRY003)
582-584: Avoid specifying long messages outside the exception class
(TRY003)
586-588: Avoid specifying long messages outside the exception class
(TRY003)
591-593: Avoid specifying long messages outside the exception class
(TRY003)
596-598: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
885-885: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
    if epsilon is None:
        epsilon = torch.finfo(input.dtype).eps

    if len(input.shape) != 2:
        raise ValueError(
            f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}."
        )
    if len(residual_in.shape) != 2:
        raise ValueError(
            f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}."
        )
    if gamma.numel() != input.shape[1]:
        raise ValueError(
            f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements."
        )
    if output is None:
        output = torch.empty_like(input)
    elif len(output.shape) != 2:
        raise ValueError(
            f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}."
        )
    if residual_out is None:
        residual_out = torch.empty_like(residual_in)
    elif len(residual_out.shape) != 2:
        raise ValueError(
            f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}."
        )

    module = get_trtllm_mnnvl_comm_module()

    use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
        strategy == MNNVLAllreduceFusionStrategy.AUTO
        and MNNVLAllreduceFusionStrategy.is_one_shot(
            workspace.tp_size,
            input.shape[0],
            input.shape[1],
            input.dtype,
        )
    )

    module.trtllm_mnnvl_allreduce_fusion(
        input,
        workspace.mc_ptr,
        workspace.uc_ptrs_dev,
        workspace.uc_ptr_local,
        workspace.buffer_flags,
        workspace.tp_size,
        workspace.rank,
        True,  # RMSNorm Fusion
        launch_with_pdl,
        use_oneshot,
        output,
        residual_out,
        residual_in,
        gamma,
        epsilon,
Restore RMSNorm epsilon default to 1e-5.
Overriding epsilon with torch.finfo(...).eps replaces the kernel’s built-in 1e-5 default (see trtllm_mnnvl_allreduce_fusion in csrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM. Please keep the default at 1e-5 (or leave the argument unset so the kernel default is used).
- if epsilon is None:
-     epsilon = torch.finfo(input.dtype).eps
+ if epsilon is None:
+     epsilon = 1e-5
🧰 Tools
🪛 Ruff (0.14.5)
338-340: Avoid specifying long messages outside the exception class
(TRY003)
342-344: Avoid specifying long messages outside the exception class
(TRY003)
346-348: Avoid specifying long messages outside the exception class
(TRY003)
352-354: Avoid specifying long messages outside the exception class
(TRY003)
358-360: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around lines 334-389, the code currently
overrides epsilon with torch.finfo(input.dtype).eps when epsilon is None which
changes RMSNorm behavior (especially for fp16); restore the original default by
either leaving epsilon as None so the CUDA kernel's built-in 1e-5 is used or
explicitly set epsilon = 1e-5 when epsilon is None, and remove the
torch.finfo(...) fallback so we do not substitute ~1e-3 for fp16.
include/flashinfer/utils.cuh
Outdated
inline int GetCudaMultiProcessorCount() {
  static int sm_count = 0;
  if (sm_count == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, device_id);
    sm_count = device_prop.multiProcessorCount;
  }
  return sm_count;
Make GetCudaMultiProcessorCount thread‑safe and clarify multi‑device semantics
- static int sm_count is written without synchronization; concurrent calls from multiple host threads can cause a data race and undefined behavior.
- The function also permanently caches the SM count of whichever device is current on the first call; if the process later switches devices, the cached value will be wrong.
Consider making the cache atomic (or using std::call_once) and, if needed, keying by device ID. For example:
- inline int GetCudaMultiProcessorCount() {
- static int sm_count = 0;
- if (sm_count == 0) {
- int device_id;
- cudaGetDevice(&device_id);
- cudaDeviceProp device_prop;
- cudaGetDeviceProperties(&device_prop, device_id);
- sm_count = device_prop.multiProcessorCount;
- }
- return sm_count;
-}
+ inline int GetCudaMultiProcessorCount() {
+ static std::atomic<int> sm_count{0};
+ int cached = sm_count.load(std::memory_order_relaxed);
+ if (cached == 0) {
+ int device_id = 0;
+ cudaGetDevice(&device_id);
+ cudaDeviceProp device_prop{};
+ cudaGetDeviceProperties(&device_prop, device_id);
+ cached = device_prop.multiProcessorCount;
+ sm_count.store(cached, std::memory_order_relaxed);
+ }
+ return cached;
+}(You’d also need to #include <atomic> in this header.)
📝 Committable suggestion
inline int GetCudaMultiProcessorCount() {
  static std::atomic<int> sm_count{0};
  int cached = sm_count.load(std::memory_order_relaxed);
  if (cached == 0) {
    int device_id = 0;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop{};
    cudaGetDeviceProperties(&device_prop, device_id);
    cached = device_prop.multiProcessorCount;
    sm_count.store(cached, std::memory_order_relaxed);
  }
  return cached;
}
🤖 Prompt for AI Agents
In include/flashinfer/utils.cuh around lines 292-301, the function uses a
non-synchronized static sm_count that creates a data race and also permanently
caches the SM count for whichever device was current on first call; fix by
making the lookup thread-safe and per-device: first call cudaGetDevice() to get
the current device id, then fetch properties if no cached value exists for that
device; protect the cache with thread-safety (e.g., use std::call_once
per-device, or a mutex + an unordered_map<int,int> keyed by device id, or an
atomic<int> array sized by device count) and include the required headers
(<mutex>, <unordered_map> or <atomic>); return the cached per-device SM count
after initialization.
Actionable comments posted: 2
♻️ Duplicate comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)
56-69: Guard RMSNorm fusion against missing residual_in and validate its shape. The precondition at line 56 enforces presence of residual_out, gamma, and epsilon when rmsnorm_fusion=true, but omits residual_in. The kernel will dereference residualInPtr unconditionally when RMSNormFusion is true, causing undefined behavior if residual_in is absent. Additionally, shape validation (lines 61-68) only covers residual_out and gamma; residual_in is not validated.
Extend the precondition to include residual_in:
  - TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
  + TVM_FFI_ICHECK((residual_out.has_value() && residual_in.has_value() &&
  +                 gamma.has_value() && epsilon.has_value()) ||
                    !rmsnorm_fusion)
  -     << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
  +     << "residual_out, residual_in, gamma, and epsilon must be provided if rmsnorm_fusion is true";
Add shape validation for residual_in within the if (rmsnorm_fusion) block:
    if (rmsnorm_fusion) {
      TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                     residual_out.value().size(1) == token_dim)
          << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
          << ") but got (" << residual_out.value().size(0) << ", "
          << residual_out.value().size(1) << ")";
  +   TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
  +                  residual_in.value().size(1) == token_dim)
  +       << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim
  +       << ") but got (" << residual_in.value().size(0) << ", "
  +       << residual_in.value().size(1) << ")";
      TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
          << "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
          << gamma.value().size(0) << ")";
    }
flashinfer/comm/trtllm_mnnvl_ar.py (1)
331-332: Restore RMSNorm epsilon default to 1e-5. Overriding epsilon with torch.finfo(input.dtype).eps replaces the kernel's built-in 1e-5 default (see line 91 in csrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM. Apply this diff:
      if epsilon is None:
  -       epsilon = torch.finfo(input.dtype).eps
  +       epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
502-504: Clarify the assertion for legacy API compatibility. The assertion at lines 502-504 will fail with a cryptic message if wait_for_results=False is passed. Since this is deprecated legacy code, the assertion is reasonable, but consider improving the error message for clarity:
  - assert wait_for_results and (out is not None), (
  -     "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead."
  - )
  + if not wait_for_results or out is None:
  +     raise ValueError(
  +         "Legacy trtllm_mnnvl_all_reduce requires wait_for_results=True and a valid output tensor. "
  +         "Please use the new trtllm_mnnvl_allreduce API instead."
  +     )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (3)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)cudaGetErrorString(146-147)csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-219)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1121-1201)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1101-1118)lamport_initialize(1160-1161)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(31-108)trtllm_mnnvl_allreduce_fusion(31-37)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(298-388)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(226-295)get_allreduce_mnnvl_workspace(395-448)get_required_buffer_size_bytes(116-141)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
74-76: Avoid specifying long messages outside the exception class
(TRY003)
258-260: Avoid specifying long messages outside the exception class
(TRY003)
265-267: Avoid specifying long messages outside the exception class
(TRY003)
335-337: Avoid specifying long messages outside the exception class
(TRY003)
339-341: Avoid specifying long messages outside the exception class
(TRY003)
343-345: Avoid specifying long messages outside the exception class
(TRY003)
349-351: Avoid specifying long messages outside the exception class
(TRY003)
355-357: Avoid specifying long messages outside the exception class
(TRY003)
497-499: Avoid specifying long messages outside the exception class
(TRY003)
568-570: Avoid specifying long messages outside the exception class
(TRY003)
574-576: Avoid specifying long messages outside the exception class
(TRY003)
579-581: Avoid specifying long messages outside the exception class
(TRY003)
583-585: Avoid specifying long messages outside the exception class
(TRY003)
588-590: Avoid specifying long messages outside the exception class
(TRY003)
593-595: Avoid specifying long messages outside the exception class
(TRY003)
tests/comm/test_trtllm_mnnvl_allreduce.py
280-280: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
226-232: Add a workspace capacity check to prevent buffer overflow. The new trtllm_mnnvl_allreduce function doesn't verify that the input tensor fits within the allocated workspace buffer. A previous review comment suggested adding a check similar to the legacy API's if inp.shape[0] > buffer_M validation. While the author questioned whether this should be on the execution path, buffer overflow can cause crashes or silent memory corruption. Consider adding a validation check:
    required_size = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
        workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
    )
    if required_size > workspace.buffer_size_bytes:
        raise ValueError(
            f"Input tensor requires {required_size} bytes but workspace only has "
            f"{workspace.buffer_size_bytes} bytes. Please increase workspace size."
        )
Based on past review comments, the maintainer questioned if this check should be on the execution path. If it is intentionally omitted for performance, please document this as a user responsibility in the docstring.
if fusion:
    # Fused case: AllReduce + Residual Add + RMS Norm
    allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
    residual_out = allreduce_result + residual  # Add residual
    norm_out = rmsnorm(
        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
    )
Epsilon mismatch between test and kernel default.
Line 263 uses torch.finfo(dtype).eps for the reference RMSNorm calculation, but the kernel uses a default of 1e-5 (see csrc/trtllm_mnnvl_allreduce.cu line 91). This inconsistency will cause test failures or require overly loose tolerances.
Since flashinfer/comm/trtllm_mnnvl_ar.py also incorrectly defaults to torch.finfo(input.dtype).eps at line 332, these two issues are related. Once the main API is fixed to use 1e-5, update this test accordingly:
residual_out = allreduce_result + residual # Add residual
norm_out = rmsnorm(
- residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+ residual_out, norm_weight, 1e-5, enable_pdl=False
)
📝 Committable suggestion
if fusion:
    # Fused case: AllReduce + Residual Add + RMS Norm
    allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
    residual_out = allreduce_result + residual  # Add residual
    norm_out = rmsnorm(
        residual_out, norm_weight, 1e-5, enable_pdl=False
    )
🤖 Prompt for AI Agents
In tests/comm/test_trtllm_mnnvl_allreduce.py around lines 258 to 264, the test
uses torch.finfo(dtype).eps as the epsilon for the reference RMSNorm but the
kernel defaults to 1e-5; change the test to pass epsilon=1e-5 to rmsnorm so the
reference matches the kernel default (and after fixing
flashinfer/comm/trtllm_mnnvl_ar.py to default to 1e-5, ensure this test
continues to use that same 1e-5 constant for consistency).
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)
566-664: Close remaining POSIX FDs in the IPC path to avoid leaks. In the POSIX handle path of _alloc_mn_mcast_mem, a few FDs are still never closed:
- local_shareable_uc_handle returned by cuMemExportToShareableHandle (line 958) is used in the IPC ring allgather but never closed.
- During the ring, each rank sends its local_shareable_uc_handle to all peers, including itself. The self-recv for p == group_rank populates all_shareable_uc_handles[self.group_rank], but that FD is never imported (due to if p != self.group_rank) and also never closed.
You already close imported POSIX FDs after cuMemImportFromShareableHandle and close the multicast FD after import; closing the remaining two FDs will complete the cleanup and prevent per-allocation FD leaks in long-running jobs. One way to fix this:
    if (
        self._shareable_handle_type
        == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
    ):
        # All-gather fabric handles
        all_shareable_uc_handles = self.comm_backend.allgather(
            local_shareable_uc_handle.data
        )
    else:
        # Implement the allgather logic with ipc socket
        all_shareable_uc_handles = [None] * self.group_size
        for i in range(self.group_size):
            self.comm_backend.barrier()
            # Send to peer at offset i
            dest_rank = (self.group_rank + i) % self.group_size
            self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
            # Receive from peer at offset -i
            src_rank = (self.group_rank + self.group_size - i) % self.group_size
            all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
  + if (
  +     self._shareable_handle_type
  +     == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
  + ):
  +     # Close our exported UC handle FD and the self-received FD
  +     os.close(local_shareable_uc_handle)
  +     if all_shareable_uc_handles[self.group_rank] is not None:
  +         os.close(all_shareable_uc_handles[self.group_rank])
The existing per-peer close after import:
    if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
        os.close(all_shareable_uc_handles[p])
and the multicast close:
    if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
        os.close(shareable_mc_handle)
can stay as-is. Together with the new __del__ logic calling self._ipc_socket.close(), this fully addresses the descriptor-leak concern in the IPC path.
Also applies to: 957-1005, 1008-1055
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
233-271: Align the reference RMSNorm epsilon with the kernel default (still using torch.finfo(dtype).eps). prepare_test_data still uses torch.finfo(dtype).eps as the epsilon for the reference RMSNorm:
    norm_out = rmsnorm(
        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
    )
while the kernel and test harness default to eps = 1e-5 (see run_mnnvl_ar_full and the C++ FFI wrapper's params.epsilon default). This inconsistency can mask subtle discrepancies behind loose tolerances or cause avoidable test drift. To keep the reference path exactly aligned with the implementation, switch this to the same constant:
  - norm_out = rmsnorm(
  -     residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
  - )
  + norm_out = rmsnorm(
  +     residual_out,
  +     norm_weight,
  +     1e-5,
  +     enable_pdl=False,
  + )
(or better, reuse the same eps value passed into run_mnnvl_ar_full to avoid hard-coding the constant twice).
🧹 Nitpick comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
100-114: Ensure epsilon defaults stay consistent with the Python API and tests. Here params.epsilon falls back to 1e-5 when the Optional epsilon is not provided:
    params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;
The Python wrapper in flashinfer/comm/trtllm_mnnvl_ar.py and the tests in tests/comm/test_trtllm_mnnvl_allreduce.py should use the same default to avoid silent discrepancies between the kernel and reference paths. The core test harness already sets eps = 1e-5; the remaining mismatch is in the reference RMSNorm computation (see prepare_test_data), which still uses torch.finfo(dtype).eps.
flashinfer/comm/mnnvl.py (3)
132-150: Fix alloc_and_copy_to_cuda return type vs None behavior. alloc_and_copy_to_cuda is annotated as returning int but still returns None for an empty host_ptr_array:
    def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
        if not host_ptr_array:
            return None
Current call sites (signal_pads and uc_ptrs) always pass non-empty lists, so behavior is correct, but the annotation is misleading and could hide bugs if the function gets reused. Either make the return type explicit about None or enforce non-emptiness by raising:
  -def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
  -    if not host_ptr_array:
  -        return None
  +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
  +    if not host_ptr_array:
  +        raise ValueError("host_ptr_array must be non-empty")
(or change the annotation to Optional[int] if you prefer the sentinel behavior).
885-893: IPC opId bootstrap looks fine; consider documenting ordering guarantees. _init_ipc_socket uses an MPI-like bcast to distribute a randomly chosen opId from rank 0, then uses it to construct IpcSocket endpoints on all ranks. This nicely avoids hard-coding operation IDs and lines up with the C++ IPC model. Given the reliance on collective barriers around send_fd/recv_fd, it would help future maintainers to mention in a comment here that all ranks are expected to participate in the same sequence of IPC operations for a given opId, and that mismatched usage will deadlock. The code is correct as written; this is just a documentation/clarity suggestion.
1143-1170: McastGPUBuffer workspace integration and pointer getters look consistent. The new comm_backend_for_handle_transfer parameter is threaded through to McastDeviceMemory, and the added get_unicast_ptr wrapper simply delegates to mcast_device_memory.get_unicast_ptr(rank). This lines up with how tests and get_allreduce_mnnvl_workspace use these pointers and keeps pointer access encapsulated. The placeholder buffer-view methods (get_multicast_buffer, get_unicast_buffer) are clearly marked NotImplementedError, so they won't be hit accidentally. If you plan to expose tensor views later, you can implement them via create_tensor_from_cuda_memory.
Also applies to: 1209-1212
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/mnnvl.py (18 hunks)
- include/flashinfer/utils.cuh (2 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/utils.cuh
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(298-388)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(226-295)get_allreduce_mnnvl_workspace(395-448)get_required_buffer_size_bytes(116-141)flashinfer/comm/mnnvl.py (14)
barrier(168-168)barrier(227-228)Get_rank(156-156)Get_rank(215-216)Get_size(159-159)Get_size(218-219)bcast(165-165)bcast(224-225)get_multicast_ptr(871-875)get_multicast_ptr(1205-1207)get_buffer_ptrs_dev(857-859)get_buffer_ptrs_dev(1213-1215)get_unicast_ptr(861-869)get_unicast_ptr(1209-1211)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-219)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
888-888: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
include/flashinfer/utils.cuh (1)
293-307: Thread-safe SM count cache looks good; confirm the single-GPU-per-process assumption. Using static std::atomic<int> with relaxed loads/stores makes this helper thread-safe and avoids the previous static int data race. The comment explicitly assumes one CUDA device per process, since the cached sm_count is never recomputed if the current device changes. If any call sites may run in a multi-GPU-per-process setup, consider extending this to a per-device cache (e.g., keyed by device id) rather than a single global integer; otherwise, this implementation is fine as long as the single-device assumption holds.
csrc/trtllm_mnnvl_allreduce.cu (1)
41-76: RMSNorm fusion validation and shape checks look correct. The updated precondition now correctly requires residual_in, residual_out, gamma, and epsilon when rmsnorm_fusion is true, and the subsequent shape checks on residual_in, residual_out, and gamma guard the fused path against mismatched tensors. This should prevent the fused kernels from ever seeing invalid residual/norm inputs via the FFI boundary. The overall parameter wiring into AllReduceFusionParams (including buffer pointers and flags) also looks consistent with the Python side.
flashinfer/comm/mnnvl.py (1)
781-790: Good: the IPC socket is now closed in the destructor. The addition of:
    if hasattr(self, "_ipc_socket"):
        self._ipc_socket.close()
inside __del__ ensures the Unix domain socket is closed and, for non-abstract sockets, the filesystem entry is unlinked. This addresses the earlier socket-leak concern while remaining safe when construction fails before _ipc_socket is set.
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
16-103: Test harness refactor cleanly exercises both the refactored and legacy APIs. The new helpers (row_linear_residual_norm_fusion_forward, _legacy, run_mnnvl_ar_full) and parametrized tests (test_mnnvl_allreduce_refactored, test_mnnvl_allreduce_legacy) do a good job of:
- Sharing core logic between fused and non‑fused paths.
- Covering both the new workspace‑based API and the legacy pointer‑based API.
- Exercising a variety of sequence lengths, dtypes, and hidden sizes.
- Integrating MPI barriers and rank‑aware logging to make multi‑rank failures diagnosable.
Once the epsilon alignment in prepare_test_data is fixed, this test suite should give solid coverage for the new fused implementation and its backward-compatibility guarantees.
Also applies to: 274-397, 439-465
flashinfer/comm/trtllm_mnnvl_ar.py
Outdated
    comm_backend: Optional[CommBackend] = None,
):
    """
    Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
Is there a way we can check this?
Forgot to update the doc. Fixed.
def __init__(
    self,
    mapping: Mapping,
    buffer_size_in_bytes: Optional[int] = None,
Could you provide guidance for buffer_size_in_bytes? E.g., as a function of the number of tokens and hidden size? Or just refer to get_required_buffer_size_bytes?
just refer to get_required_buffer_size_bytes
        comm_backend,
    )

    # We use FP32 for sentinel value regardless of the real dtype
Why? Before, we used the real dtype.
We are using LDG.128 for all data reads and writes, and if the allocation is 4-byte (word) aligned, reading or writing each word can be considered atomic. Thus, it is sufficient to check at FP32 granularity regardless of the dtype. This also simplifies buffer management by decoupling the sentinel value from the dtype.
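To illustrate the point, a host-side sketch of FP32-granularity sentinel clearing might look like the following. The negative-zero sentinel matches the non-negZero wait condition mentioned earlier in this review, but the actual lamport_initialize implementation may differ in detail, and the byte-buffer layout here is an assumption.

```python
# Illustrative sketch only: clear a lamport buffer at FP32 (4-byte word)
# granularity regardless of the payload dtype, using negative zero as the
# "not yet written" sentinel. The real initializer in this PR may differ.
import torch

def clear_lamport_buffer_fp32(byte_buffer: torch.Tensor) -> None:
    # byte_buffer: contiguous uint8 CUDA tensor whose length is a multiple of 4.
    assert byte_buffer.dtype == torch.uint8 and byte_buffer.numel() % 4 == 0
    words = byte_buffer.view(torch.float32)  # reinterpret as 32-bit words
    words.fill_(-0.0)  # every 4-byte word now holds the sentinel bit pattern
```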
def __init__(
    self,
    mapping: Mapping,
    buffer_size_in_bytes: Optional[int] = None,
Another option would be to replace "buffer_size_in_bytes" with the parameters that get_required_buffer_size_bytes takes, and just call that from the init function. It seems more user friendly.
If you do want to just allocate a blob of memory, we could still keep buffer_size_in_bytes as an additional parameter that would override whatever is calculated.
I don't think that is a good design, as it might give the user the impression that the allocated workspace will ONLY support that set of parameters (max_num_tokens, hidden_dim, dtype, strategy).
In reality, workspace usage is quite flexible: as long as the required workspace size is smaller than the allocation, it will work. Thus, the intended usage is that the user checks the required workspace size (or we can check it when calling the allreduce function, at some cost) before using it.
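A small sketch of that intended flow, assuming the get_required_buffer_size_bytes signature shown in the review snippets (size the workspace once for the worst case, then reuse it for any smaller problem):

```python
# Sketch of the intended sizing flow; the signature of
# get_required_buffer_size_bytes follows the snippet quoted earlier in this
# review and is an assumption, not the verified final API.
import torch
from flashinfer.comm.trtllm_mnnvl_ar import (
    MNNVLAllreduceFusionStrategy,
    MNNVLAllreduceFusionWorkspace,
)

def make_workspace(mapping, tp_size: int, max_tokens: int, hidden_dim: int):
    # Size the buffers for the largest shape the model will ever see; any
    # smaller invocation can reuse the same workspace.
    required = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
        tp_size,
        max_tokens,
        hidden_dim,
        torch.bfloat16,
        MNNVLAllreduceFusionStrategy.AUTO,
    )
    return MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes=required)
```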
csrc/trtllm_mnnvl_allreduce.cu
Outdated
    Optional<TensorView> out) {
  cudaSetDevice(in.device().device_id);
  auto stream = get_stream(in.device());
  // FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
I don't think it is a problem
Comment removed.
flashinfer/comm/trtllm_mnnvl_ar.py
Outdated
    AUTO = 99

    @staticmethod
    def is_one_shot(
Maybe rename this to "heuristic_for_one_shot" or something like that.
Done.
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
361-362: Critical: Restore epsilon default to 1e-5 to match the kernel. This epsilon fallback was flagged as critical in a previous review but remains unresolved. Using torch.finfo(input.dtype).eps sets epsilon to approximately 1e-3 for fp16, diverging from the kernel's built-in 1e-5 default (see csrc/trtllm_mnnvl_allreduce.cu line 96). This materially alters RMSNorm results and breaks compatibility with TensorRT-LLM. Apply this fix:
  - if epsilon is None:
  -     epsilon = torch.finfo(input.dtype).eps
  + if epsilon is None:
  +     epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
118-136: Consider replacing @functools.cache on an instance method. Using @functools.cache on an instance method can prevent the instance from being garbage collected, leading to memory leaks. Since this method takes self as the first parameter, the cache will hold references to the instance. Consider either (see the sketch after this list):
- Making this a standalone function that takes workspace parameters explicitly
- Using @functools.lru_cache(maxsize=...) with a reasonable limit
- Implementing manual caching in the instance if needed
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
- flashinfer/comm/mapping.py (6): Mapping (21-475), rank (311-312), rank (315-322), tp_rank (325-326), local_rank (391-392), is_multi_node (403-404)
- flashinfer/jit/comm.py (1): gen_trtllm_mnnvl_comm_module (33-39)
- flashinfer/utils.py (2): register_custom_op (273-282), register_custom_op (292-311)
- flashinfer/comm/mnnvl.py (13): McastGPUBuffer (1135-1215), CommBackend (152-171), MPIBackend (211-232), lamport_initialize (1115-1132), lamport_initialize (1174-1175), barrier (168-168), barrier (227-228), get_buffer_ptrs_dev (857-859), get_buffer_ptrs_dev (1213-1215), get_unicast_ptr (861-869), get_unicast_ptr (1209-1211), get_multicast_ptr (871-875), get_multicast_ptr (1205-1207)
- csrc/trtllm_mnnvl_allreduce.cu (2): trtllm_mnnvl_allreduce_fusion (29-113), trtllm_mnnvl_allreduce_fusion (29-35)

csrc/trtllm_mnnvl_allreduce.cu (3)
- flashinfer/comm/cuda_ipc.py (2): cudaSetDevice (149-150), cudaGetErrorString (146-147)
- csrc/tvm_ffi_utils.h (1): get_stream (272-274)
- flashinfer/comm/trtllm_mnnvl_ar.py (1): trtllm_mnnvl_allreduce_fusion (192-243)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
282-284: Avoid specifying long messages outside the exception class
(TRY003)
289-291: Avoid specifying long messages outside the exception class
(TRY003)
303-305: Avoid specifying long messages outside the exception class
(TRY003)
365-367: Avoid specifying long messages outside the exception class
(TRY003)
369-371: Avoid specifying long messages outside the exception class
(TRY003)
373-375: Avoid specifying long messages outside the exception class
(TRY003)
379-381: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
398-400: Avoid specifying long messages outside the exception class
(TRY003)
528-530: Avoid specifying long messages outside the exception class
(TRY003)
599-601: Avoid specifying long messages outside the exception class
(TRY003)
605-607: Avoid specifying long messages outside the exception class
(TRY003)
610-612: Avoid specifying long messages outside the exception class
(TRY003)
614-616: Avoid specifying long messages outside the exception class
(TRY003)
619-621: Avoid specifying long messages outside the exception class
(TRY003)
624-626: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
29-113: LGTM! Fusion entry point is well-structured. The refactored entry point properly validates all inputs, including the RMSNorm fusion parameters that were flagged in previous reviews. The dispatch logic cleanly selects between oneshot and twoshot strategies, and error messages are clear and actionable.
flashinfer/comm/trtllm_mnnvl_ar.py (3)
30-48: Strategy enum and heuristic look good. The `MNNVLAllreduceFusionStrategy` enum provides a clear interface for selecting between oneshot and twoshot approaches, with a sensible AUTO mode that uses an empirically derived threshold.
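For illustration, a sketch of how an AUTO-mode heuristic like this can be resolved on the Python side; the ONESHOT/TWOSHOT numeric values and the 64 KiB cutoff are placeholders, not the constants this PR actually uses:

```python
import enum


class MNNVLAllreduceFusionStrategy(enum.IntEnum):
    ONESHOT = 0   # placeholder value
    TWOSHOT = 1   # placeholder value
    AUTO = 99

    @staticmethod
    def heuristic_for_one_shot(num_tokens: int, hidden_dim: int, dtype_itemsize: int) -> bool:
        # Small messages favor the one-shot path; this 64 KiB threshold is illustrative only.
        return num_tokens * hidden_dim * dtype_itemsize <= 64 * 1024


def resolve_use_oneshot(strategy: MNNVLAllreduceFusionStrategy, num_tokens: int,
                        hidden_dim: int, dtype_itemsize: int) -> bool:
    """Return True if the one-shot kernel should be used."""
    if strategy == MNNVLAllreduceFusionStrategy.AUTO:
        return MNNVLAllreduceFusionStrategy.heuristic_for_one_shot(
            num_tokens, hidden_dim, dtype_itemsize)
    return strategy == MNNVLAllreduceFusionStrategy.ONESHOT
```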
250-326: Buffer size validation properly implemented. The function now includes the buffer size check that was requested in previous reviews (lines 300-305), preventing potential out-of-bounds access. Input validation is comprehensive and error messages are clear.
422-646: Deprecation strategy is well-executed. The legacy APIs are properly marked with `@deprecated` decorators and include clear migration guidance. The wrappers correctly redirect to the new fusion-based implementations, maintaining backward compatibility while encouraging adoption of the improved APIs.
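A minimal sketch of the deprecate-and-redirect pattern described here; the `deprecated` decorator below is a generic stand-in rather than FlashInfer's actual helper, and the function names and signatures are illustrative placeholders:

```python
import functools
import warnings


def deprecated(message: str):
    # Generic stand-in for a deprecation decorator; FlashInfer's own helper may differ.
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return fn(*args, **kwargs)
        return inner
    return wrap


def trtllm_mnnvl_allreduce(inp, workspace, **kwargs):
    # Placeholder for the new fusion-based API so the example is self-contained;
    # the real function lives in flashinfer/comm/trtllm_mnnvl_ar.py.
    raise NotImplementedError


@deprecated("This entry point is deprecated; use trtllm_mnnvl_allreduce instead.")
def legacy_all_reduce(inp, workspace, **kwargs):
    # Legacy wrapper: warn at the call site, then redirect to the new implementation.
    return trtllm_mnnvl_allreduce(inp, workspace, **kwargs)
```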
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-149: Fix return type inconsistency.

The function returns `None` at line 137 when `host_ptr_array` is empty, but the return type annotation at line 132 indicates `int`. This creates a type mismatch. Consider one of these fixes:

Option 1: Return `Optional[int]` and update callers to handle None:

```diff
-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
```

Option 2: Raise an error instead of returning None:

```diff
     if not host_ptr_array:
-        return None
+        raise ValueError("host_ptr_array cannot be empty")
```
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
377-378: Restore RMSNorm epsilon default to 1e-5.

Overriding `epsilon` with `torch.finfo(input.dtype).eps` replaces the kernel's built-in 1e-5 default (see `trtllm_mnnvl_allreduce_fusion` in `csrc/trtllm_mnnvl_allreduce.cu` line ~35: `params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5`). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking numerical parity.

Apply this diff to fix:

```diff
     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
```
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
134-152: Consider alternatives to `@functools.cache` on instance methods.

Using `@functools.cache` (or `@lru_cache`) on instance methods can prevent garbage collection of instances because the cache holds references to bound methods, which in turn hold references to `self`. Since `MNNVLAllreduceFusionWorkspace` instances are likely long-lived in typical usage, this may be acceptable, but consider these alternatives:

- Use `@functools.lru_cache(maxsize=128)` to limit cache growth
- Move caching logic to a module-level cache keyed on relevant parameters
- Document the caching behavior and its memory implications

Based on learnings

Apply this diff if you want to limit cache size:

```diff
-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def is_buffer_size_sufficient(
```

flashinfer/comm/mnnvl.py (2)
640-654: Prefix unused unpacked variables with underscore.

The variables `msg`, `flags`, and `addr` from `recvmsg` are unpacked but never used. Prefix them with `_` to indicate they're intentionally ignored.

Apply this diff:

```diff
-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(
```
893-900: Consider using `secrets` module for opId generation.

While cryptographic randomness is not strictly required for socket naming, using `secrets.randbelow(2**64)` instead of `random.randint` provides better collision resistance if multiple jobs run concurrently on the same node.

Apply this diff:

```diff
+import secrets
+
 def _init_ipc_socket(self):
     if self.group_rank == 0:
-        # Gnerate the opId
-        opId = random.randint(0, 2**64 - 1)
+        # Generate the opId
+        opId = secrets.randbelow(2**64)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/comm/mnnvl.py (19 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/mnnvl.py (1)
- flashinfer/cuda_utils.py (1): checkCudaErrors (51-61)

flashinfer/comm/trtllm_mnnvl_ar.py (4)
- flashinfer/comm/mapping.py (5): rank (311-312), rank (315-322), tp_rank (325-326), local_rank (391-392), is_multi_node (403-404)
- flashinfer/utils.py (2): register_custom_op (273-282), register_custom_op (292-311)
- flashinfer/comm/mnnvl.py (13): McastGPUBuffer (1143-1224), CommBackend (152-171), MPIBackend (211-232), lamport_initialize (1123-1140), lamport_initialize (1183-1184), barrier (168-168), barrier (227-228), get_buffer_ptrs_dev (857-859), get_buffer_ptrs_dev (1222-1224), get_unicast_ptr (861-869), get_unicast_ptr (1218-1220), get_multicast_ptr (871-875), get_multicast_ptr (1214-1216)
- csrc/trtllm_mnnvl_allreduce.cu (2): trtllm_mnnvl_allreduce_fusion (29-113), trtllm_mnnvl_allreduce_fusion (29-35)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
896-896: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
298-300: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
319-321: Avoid specifying long messages outside the exception class
(TRY003)
381-383: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
389-391: Avoid specifying long messages outside the exception class
(TRY003)
395-397: Avoid specifying long messages outside the exception class
(TRY003)
401-403: Avoid specifying long messages outside the exception class
(TRY003)
414-416: Avoid specifying long messages outside the exception class
(TRY003)
544-546: Avoid specifying long messages outside the exception class
(TRY003)
615-617: Avoid specifying long messages outside the exception class
(TRY003)
621-623: Avoid specifying long messages outside the exception class
(TRY003)
626-628: Avoid specifying long messages outside the exception class
(TRY003)
630-632: Avoid specifying long messages outside the exception class
(TRY003)
635-637: Avoid specifying long messages outside the exception class
(TRY003)
640-642: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
266-341: LGTM! The function correctly validates inputs, selects strategy, checks buffer size sufficiency (addressing past review feedback), and invokes the fusion kernel with appropriate parameters.
flashinfer/comm/mnnvl.py (1)
788-789: LGTM! The IPC socket cleanup correctly uses `hasattr` to check for existence before closing, addressing the file descriptor leak concern from past reviews.
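A toy sketch of that guarded-cleanup pattern, not the actual code in `flashinfer/comm/mnnvl.py`; the class and attribute names are assumptions:

```python
import socket


class McastDeviceMemoryLike:
    """Illustrative stand-in for an object that may own an IPC socket."""

    def __init__(self, use_ipc_socket: bool):
        if use_ipc_socket:
            # Only the IPC-socket bootstrap path creates this attribute.
            self.ipc_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)

    def __del__(self):
        # Guard with hasattr: __init__ may have skipped the IPC path (or failed)
        # before `ipc_socket` was ever assigned, so closing unconditionally would raise.
        if hasattr(self, "ipc_socket"):
            self.ipc_socket.close()
```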
📌 Description
This PR ports all changes from TensorRT-LLM#8018 into FlashInfer.
Apart from the changes mentioned in the original PR, this PR also introduces new API interfaces, `trtllm_mnnvl_allreduce` and `trtllm_mnnvl_fused_allreduce_add_rmsnorm`, to replace the original ones. Workspace allocation is wrapped in a dedicated class created with a given buffer size, so the user does not need to worry about the details inside.
This PR also adds support for IPC-socket-based mcast device memory bootstrap so that it can run on DGX machines that do not support fabric handles.
@wenscarl This PR also incorporates the changes made in #2056 and should be able to replace that PR. A bcast interface is added to the comm backend, as it is needed during the handle transfer.
The old APIs are tagged as deprecated and redirected to the new APIs; users of the old APIs should not need to make any changes.
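For orientation, a rough usage sketch of the new interface, assuming the workspace is constructed from a `Mapping` plus a buffer size and that the fused variant returns the normalized output together with the updated residual; the argument order, keyword names, and `Mapping` constructor arguments below are illustrative guesses, not the exact signatures:

```python
import torch
from flashinfer.comm import trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping

# Assumed setup: `mapping` describes this rank's position in the TP group; check
# mapping.py for the exact constructor arguments.
mapping = Mapping(world_size=8, rank=0, gpus_per_node=8, tp_size=8)

num_tokens, hidden_dim = 128, 4096
workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, buffer_size_bytes=16 << 20)

hidden = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(hidden)
gamma = torch.ones(hidden_dim, dtype=torch.bfloat16, device="cuda")

# Plain allreduce.
out = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(hidden, workspace)

# Fused allreduce + residual add + RMSNorm; epsilon passed explicitly to match the
# kernel's 1e-5 default discussed in the review above.
norm_out, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm(
    hidden, residual, gamma, workspace, epsilon=1e-5
)
```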
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests