
Conversation

@timlee0212 (Contributor) commented Nov 20, 2025

📌 Description

This PR ports all changes from TensorRT-LLM#8018 into FlashInfer.

Apart from the changes mentioned in the original PR, this PR also introduces new API interfaces, trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm, to replace the original ones. Workspace allocation is wrapped in a dedicated class created with a given buffer size, so the user does not need to worry about the internal details.

This PR adds support for IPC-socket-based mcast device memory bootstrap so that it can run on DGX machines that do not support fabric handles.

@wenscarl This PR also incorporates the changes made in #2056 and should be able to replace that PR. A bcast interface is added to the comm backend, as it is needed during the handle transfer.

The old APIs are tagged as deprecated and redirected to the new ones, so existing users of the old APIs should not need to make any changes.
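For illustration, a minimal usage sketch of the new entry points. Only the trtllm_mnnvl_allreduce call signature and the get_required_buffer_size_bytes helper are taken from the review thread below; the workspace construction itself is elided because its constructor arguments are not shown on this page, so treat the function names here as illustrative rather than as a reference.

import torch

from flashinfer.comm import trtllm_mnnvl_ar


def allreduce_example(x: torch.Tensor, workspace) -> torch.Tensor:
    # workspace: an MNNVLAllreduceFusionWorkspace already built for this TP group.
    return trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(
        x,
        workspace,
        launch_with_pdl=False,
        strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
    )


def workspace_bytes(tp_size: int, max_tokens: int, hidden_dim: int) -> int:
    # Sizing helper added in this PR: bytes needed for the shared Lamport buffers.
    return trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
        tp_size,
        max_tokens,
        hidden_dim,
        torch.bfloat16,
        trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
    )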

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Fused all-reduce with optional RMSNorm fusion, one-shot/two-shot strategies, new workspace and high-level Python APIs; IPC-based POSIX FD transfer and pluggable comm backend for handle exchange.
  • Improvements

    • Lamport-buffer fusion path for better performance and memory efficiency; stronger input/output validation and deprecation wrappers guiding migration.
  • Tests

    • MPI-aware tests added to cover fused and legacy all-reduce flows.


coderabbitai bot (Contributor) commented Nov 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replaces the legacy MNNVL all-reduce with a fused lamport-buffer allreduce exposing trtllm_mnnvl_allreduce_fusion, adds optional RMSNorm fusion and one-/two-shot dispatch, introduces IPC FD transfer via IpcSocket, adjusts CUDA headers/kernels and C++/Python parameter structs, extends Python workspace/backends, and updates MPI-aware tests.

Changes

Cohort / File(s) Summary
CUDA entry
csrc/trtllm_mnnvl_allreduce.cu
Replaced trtllm_mnnvl_all_reduce with trtllm_mnnvl_allreduce_fusion; expanded public signature and params (residual_in/out, gamma, epsilon, rmsnorm_fusion, launch_with_pdl, use_oneshot); input/output validation updated; builds AllReduceFusionParams and conditionally dispatches oneshot/twoshot fusion paths.
CUDA header / kernels
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
Replaced AllReduceParams/RMSNormParams with AllReduceFusionParams; added utils (packed loads, lamport buffer layout, grid config); implemented lamport-buffer-based oneshot/twoshot fusion kernels with integrated RMSNorm fusion paths and dispatch templates.
CUDA utils
include/flashinfer/utils.cuh
Added thread-safe cached helper flashinfer::GetCudaMultiProcessorCount() using std::atomic.
Python IPC & MNNVL backend
flashinfer/comm/mnnvl.py
Added IpcSocket for UNIX-domain FD transfer; extended CommBackend with bcast/barrier and MPI implementation; McastDeviceMemory/McastGpuBuffer accept comm_backend_for_handle_transfer, track _shareable_handle_type (FABRIC/POSIX), initialize IPC fallback, and switch export/import to use either fabric or IPC FD exchange; normalized alloc_and_copy_to_cuda to return int.
High-level Python API & workspace
flashinfer/comm/trtllm_mnnvl_ar.py
Added MNNVLAllreduceFusionStrategy enum and MNNVLAllreduceFusionWorkspace; exposed trtllm_mnnvl_allreduce_fusion kernel symbol; added trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm high-level APIs selecting one-/two-shot strategies; deprecated legacy workspace/APIs and added compatibility wrappers.
Tests
tests/comm/test_trtllm_mnnvl_allreduce.py
Added MPI-aware orchestration/barriers and logging, prepare_test_data helper, extended/refactored fusion and legacy test flows (test_mnnvl_allreduce_refactored, test_mnnvl_allreduce_legacy), and improved cleanup/traceback handling.

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant PyAPI as Python API / Workspace
    participant Strategy as Strategy Selector
    participant Comm as Comm Backend / IpcSocket / MPI
    participant Kernel as CUDA Kernel (fusion)
    participant Out as Output Buffers

    App->>PyAPI: call trtllm_mnnvl_allreduce(...) / fused variant
    PyAPI->>PyAPI: validate inputs, prepare workspace & outputs
    PyAPI->>Strategy: select ONESHOT / TWOSHOT (AUTO may inspect sizes)
    PyAPI->>Comm: exchange/share handles (MPI bcast/barrier or IpcSocket FD exchange)
    PyAPI->>Kernel: invoke trtllm_mnnvl_allreduce_fusion with params (rmsnorm_fusion, use_oneshot,...)
    rect rgb(235,245,255)
      Kernel->>Kernel: lamport-stage broadcast & rank reduction
      alt RMSNorm fusion enabled
        Kernel->>Kernel: compute RMS, apply gamma, add residuals
      end
      Kernel->>Out: write output (and residual_out if present)
    end
    Out-->>PyAPI: return tensor(s)
    PyAPI-->>App: deliver result(s)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • Lamport-buffer kernel correctness, synchronization, and multi-rank coordination.
    • One-shot vs two-shot dispatch correctness and propagation of use_oneshot.
    • Buffer pointer lifecycle and layout (bufferPtrsDev, multicastPtr, bufferPtrLocal, bufferFlags).
    • RMSNorm fusion numeric paths (FP16/BF16/CGA) and epsilon/gamma/residual handling.
    • IPC FD transfer semantics, IpcSocket race/cleanup, and MPI/IPC coordination.
    • Python/CUDA parameter wiring and deprecated wrapper correctness.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • yongwww
  • yzh119
  • cyx-6
  • wenscarl
  • IwakuraRein
  • bkryu
  • jiahanc

Poem

🐰 I hop through lamport lanes where token shards play,
I pass file descriptors by moonlight, across ranks they stray,
One-shot or two-shot, fusion tunes the run,
Gamma, residuals, and RMSNorm dance till math is done,
Buffers hum and kernels sing — hooray for fast all-reduce! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 56.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Refactor trtllm_mnnvl_allreduce' clearly and concisely summarizes the primary change in this PR, which is a refactoring of the trtllm_mnnvl_allreduce component.
Description check ✅ Passed The PR description provides comprehensive context, explains the port from TensorRT-LLM, describes new API interfaces, workspace changes, and IPC socket support. It references related issues and indicates pre-commit checks and tests were completed.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

gemini-code-assist bot (Contributor)

Summary of Changes

Hello @timlee0212, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a comprehensive refactoring of the multi-node NVLink (MNNVL) all-reduce system within FlashInfer. It unifies the all-reduce and RMSNorm operations into a single, highly configurable C++ kernel, exposed through intuitive new Python APIs. A key improvement is the new workspace management class, which automates and optimizes buffer allocation. Furthermore, the PR adds crucial support for IPC Socket-based handle transfer, broadening compatibility to hardware environments like DGX machines. These changes collectively enhance the flexibility, performance, and overall robustness of distributed computations.

Highlights

  • Refactored MNNVL All-Reduce Implementation: The core multi-node NVLink (MNNVL) all-reduce logic has been significantly refactored, consolidating all-reduce and RMSNorm functionalities into a single, flexible C++ kernel (trtllm_mnnvl_allreduce_fusion).
  • New Python API Interfaces: New Python APIs, trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm, have been introduced to provide clearer and more flexible usage for non-fused and fused all-reduce operations, respectively.
  • Enhanced Workspace Management: A new MNNVLAllreduceFusionWorkspace class now handles buffer allocation and management, simplifying the process for users and ensuring robust workspace sizing based on problem dimensions and chosen strategy (one-shot or two-shot).
  • IPC Socket Support for Handle Transfer: Support for IPC Socket-based device memory bootstrap has been added, enabling MNNVL operations on DGX machines and other environments that may not support fabric handles.
  • Deprecated Old APIs: The previous get_allreduce_mnnvl_workspace, trtllm_mnnvl_all_reduce, and trtllm_mnnvl_fused_allreduce_rmsnorm APIs are now marked as deprecated, though they remain functional for backward compatibility by internally calling the new fusion kernel (a generic sketch of this wrapper pattern follows this list).
  • Performance Optimizations: The CUDA kernels have been optimized with one-shot and two-shot strategies, refined Lamport synchronization, and dynamic grid configuration adjustments for improved efficiency across various problem sizes.
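As referenced in the deprecation highlight above, here is a generic sketch of the wrapper pattern. It is illustrative only; the actual wrappers live in flashinfer/comm/trtllm_mnnvl_ar.py and may differ in names, arguments, and messages.

import functools
import warnings


def deprecated(replacement: str):
    # Generic helper, not the FlashInfer implementation: warn on each call and
    # delegate to the wrapped function, which internally forwards to the new API.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)

        return wrapper

    return decorator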
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request is a significant refactoring of the MNNVL all-reduce implementation, introducing a new, cleaner API with a dedicated workspace manager class, and adding support for IPC sockets for single-node communication. The changes are extensive and substantially improve the code's structure and capabilities. My review focuses on ensuring backward compatibility is fully maintained as intended, removing leftover debug code, improving memory usage efficiency, adding a critical safety check for buffer sizes in the new API, and suggesting a minor precision improvement in a CUDA kernel.

Comment on lines +229 to +235
def trtllm_mnnvl_allreduce(
    input: torch.Tensor,
    workspace: MNNVLAllreduceFusionWorkspace,
    launch_with_pdl: bool,
    output: Optional[torch.Tensor] = None,
    strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
Contributor

critical

The new trtllm_mnnvl_allreduce function is missing a check to ensure that the input tensor fits within the allocated workspace. The old API had a check like if inp.shape[0] > buffer_M: raise ValueError(...). A similar check should be added here to prevent potential out-of-bounds memory access, which could lead to crashes or incorrect results. The required buffer size depends on the strategy (one-shot vs. two-shot) and can be calculated using MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes.

Contributor Author

Do we want this check on the execution path? Or should we assume it is the user's responsibility to ensure it does not overflow?

Contributor

We do want this check. I recently added it because it did bite others.

Contributor Author

Added.

        self.mcast_device_memory.lamport_initialize(rank, dtype)

-    def get_mc_buffer(
+    def get_multicast_buffer(
Contributor

high

Renaming get_mc_buffer to get_multicast_buffer is a breaking change. The pull request description states an intention to maintain backward compatibility. To align with this, please consider re-introducing get_mc_buffer as a deprecated function that calls get_multicast_buffer.

Contributor Author

This class is used internally and was left as a placeholder, not implemented, so a breaking change is fine. Tagging @nvmbreughe for confirmation.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/comm/mnnvl.py (1)

132-149: alloc_and_copy_to_cuda return type and empty-input behavior are inconsistent

The function is annotated as returning int but returns None when host_ptr_array is empty. Callers currently pass non‑empty lists, but this mismatch can trip type checkers and hide bugs if an empty list is ever passed.

Either tighten behavior or relax the signature, for example:

  • If empty input is invalid, raise:
def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        raise ValueError("host_ptr_array must be non-empty")
  • Or, if you want the sentinel, change the annotation to int | None and document the None case.
tests/comm/test_trtllm_mnnvl_allreduce.py (1)

328-427: Move allgather() and final mpi_barrier() to finally block to ensure all ranks participate in collectives

Lines 414 and 434 create a deadlock risk in error scenarios. The allgather() at line 414 is inside the except block, so only ranks that hit an exception call it. Meanwhile, the mpi_barrier() at line 434 is unconditionally called after try/except/finally. If an error occurs on some but not all ranks, failing ranks block in allgather() waiting for non-failing ranks that never enter the except block, while non-failing ranks block in the final barrier—both unable to proceed.

Move the allgather() call and final mpi_barrier() to the finally block to ensure all ranks participate in these collective operations:

rank_failed = False
try:
    ...
except Exception as e:
    rank_failed = True
    failure_message = ...
    print(failure_message)
    import traceback
    print(traceback.format_exc())
    raise
finally:
    all_failures = MPI.COMM_WORLD.allgather(rank_failed)
    if any(all_failures):
        failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
        if rank == 0:
            print(f"Test failed on ranks: {failed_ranks}")
    if "workspace" in locals():
        del workspace
    trtllm_mnnvl_ar.mpi_barrier()

This applies to line 328–426 (main try/except) and line 434 (final barrier).

🧹 Nitpick comments (8)
flashinfer/comm/mnnvl.py (1)

640-655: Minor polish: unused recvmsg outputs and predictable opId

Two small, non‑blocking cleanups:

  • In IpcSocket.recv_fd(), the unpacked msg, flags, and addr from recvmsg are unused. Renaming them to _msg, _flags, _addr will make that explicit and silence linters:
_msg, ancdata, _flags, _addr = self.sock.recvmsg(...)
  • opId for the socket name is generated with random.randint. Since it’s only used as a uniqueness hint (not security‑sensitive), this is fine; if you want to appease S311 you could switch to secrets.randbits(64) or document that it’s non‑cryptographic.

Both are optional, but would make static analysis quieter. (A minimal SCM_RIGHTS sketch follows this comment for reference.)

Also applies to: 885-889
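For reference, a minimal, generic sketch of UNIX-domain FD passing with SCM_RIGHTS using only the Python standard library; it illustrates the mechanism that IpcSocket.send_fd/recv_fd build on and is not the actual implementation.

import array
import socket


def send_fd(sock: socket.socket, fd: int) -> None:
    # A one-byte payload is required; the descriptor itself travels as ancillary data.
    sock.sendmsg(
        [b"\x00"],
        [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd]).tobytes())],
    )


def recv_fd(sock: socket.socket) -> int:
    fds = array.array("i")
    _msg, ancdata, _flags, _addr = sock.recvmsg(1, socket.CMSG_LEN(fds.itemsize))
    for cmsg_level, cmsg_type, cmsg_data in ancdata:
        if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
            fds.frombytes(cmsg_data[: fds.itemsize])
    if not fds:
        raise RuntimeError("No file descriptor received")
    return fds[0]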

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)

23-25: Explicitly include <array> and <tuple>, and guard adjustGridConfig against smCount == 0

Within this header:

  • LamportBufferLayout, LamportFlags, PackedVec, and several kernels use std::array.
  • adjustGridConfig returns std::tuple<int, int, int> and callers use std::get.

But only <type_traits> is included; <array> and <tuple> are currently pulled in (if at all) via transitive includes, which is fragile.

Also, adjustGridConfig relies on GetCudaMultiProcessorCount():

int smCount = GetCudaMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) {
  ...
}

If GetCudaMultiProcessorCount() ever returns 0 (e.g., CUDA error or misconfiguration), this loop will keep shrinking clusterSize and inflating blockSize in a somewhat opaque way.

Suggestions:

  • Add explicit includes at the top of the header:
#include <array>
#include <tuple>
  • Make adjustGridConfig robust to a 0 or negative SM count by early‑clamping:
int smCount = GetCudaMultiProcessorCount();
if (smCount <= 0) {
  // Fall back to single-SM configuration
  clusterSize = 1;
  blockSize = std::min(threadsNeeded, 1024);
  return {blockSize, clusterSize, loadsPerThread};
}

This keeps the fused path predictable even if the helper cannot obtain a valid SM count.

Also applies to: 54-177, 143-163, 291-313, 348-359, 385-419, 449-497


509-651: Confirm lamport clear / wait protocol assumptions for oneshot kernel

The oneshot fused kernel uses LamportFlags as follows:

  • Out‑of‑bounds threads call ctaArrive() then clearDirtyLamportBuf() and return.
  • In‑bounds threads:
    • write their shard into the multicast lamport buffer,
    • call ctaArrive() again,
    • then call clearDirtyLamportBuf() and spin on the Lamport buffers until all entries are non‑negZero.

This protocol assumes:

  • Every thread in the grid calls clearDirtyLamportBuf() exactly once per iteration.
  • Buffer flags and bytesToClear are correctly initialized to match the configured numTokens * tokenDim * WorldSize.

Given that this is a direct Lamport port, the logic looks consistent, but the protocol is subtle. I’d recommend:

  • Double‑checking the initialization of buffer_flags in MNNVLAllreduceFusionWorkspace matches the expectations here (current index, dirty index, bytes per buffer, and stage counts).
  • Adding a brief comment near the kernel launch documenting that buffer_flags must follow the [cur, dirty, bytes_per_buffer, dirty_num_stages, bytes_to_clear[4], access_ptr] layout used by LamportFlags.

No code change strictly required, but the invariants are nontrivial and worth locking down in comments/tests.


754-885: Two‑shot path & RMSNorm fusion: validate world sizes and loads‑per‑thread bounds

The two‑shot kernels and dispatchers introduce several constraints:

  • twoshotAllreduceFusionDispatch<T> only supports nRanks in {2, 4, 8, 16, 32, 64} and enforces tokenDim % (sizeof(float4) / sizeof(T)) == 0.
  • rmsNormLamport is instantiated with LoadsPerThread in [1, 8] and uses float4 loads into shared memory; dynamic shared memory is sized as 3 * rnBlockSize * iters * sizeof(T) and indexed accordingly.

The implementation looks coherent, but a few invariants are implicit:

  • MNNVLTwoShotStage::NUM_STAGES must stay in sync with the LamportFlags<float4> usage and the two bytes_to_clear entries in waitAndUpdate.
  • rnLoadsPerThread retrieved from adjustGridConfig must remain in [1, 8]; the default: branch already errors if it’s out of range, which is good.
  • rnClusterSize from adjustGridConfig is assumed to be <= 8 given __shared__ float sharedVal[8]; in the RMSNorm kernel.

Given these contracts, I’d suggest:

  • Adding asserts (or comments) that rnClusterSize <= 8 when CGA is used, to guard future changes to adjustGridConfig.
  • Extending tests to cover the corner cases where tokenDim is just at or above the supported boundary (e.g., maximum hidden size and multiple world sizes) so we don’t regress the FLASHINFER_CHECK conditions.

Functionally the code looks sound; this is mainly about making the implicit constraints explicit.

Also applies to: 898-959, 1062-1219

csrc/trtllm_mnnvl_allreduce.cu (1)

99-107: Error message still mentions “twoshot” even for oneshot path

Regardless of use_oneshot, the failure message says:

TVM_FFI_ICHECK(status == cudaSuccess)
    << "twoshot_allreduce_dispatch_world_size failed with error code "
    << cudaGetErrorString(status);

This is slightly misleading when the oneshot dispatch is used. Consider making the message neutral (e.g., “allreduce_fusion_dispatch failed…”) or switching on use_oneshot to provide a more accurate label. Behavior is otherwise fine.

tests/comm/test_trtllm_mnnvl_allreduce.py (2)

232-270: Use the same eps for reference RMSNorm as the fused kernel

In prepare_test_data, the fused reference path uses:

norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)

But the actual fused kernel is driven by the eps argument passed into row_linear_residual_norm_fusion_forward (eps = 1e-5 in run_mnnvl_ar_full).

To keep the reference as close as possible to the fused implementation (and not rely on loose tolerances), consider:

def prepare_test_data(..., fusion: bool, eps: float):
    ...
    if fusion:
        ...
        norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)

and threading eps through the call sites.


273-281: Annotate legacy_explicit_workspace_bytes as optional

Ruff’s RUF013 warning here is valid:

def run_mnnvl_ar_full(...,
    legacy_explicit_workspace_bytes: int = None,
    legacy_api: bool = False,
):

Changing the signature to make the optionality explicit improves readability and typing:

from typing import Optional

def run_mnnvl_ar_full(
    ...,
    legacy_explicit_workspace_bytes: Optional[int] = None,
    legacy_api: bool = False,
) -> None:
    ...

or, in Python 3.10+:

legacy_explicit_workspace_bytes: int | None = None
flashinfer/comm/trtllm_mnnvl_ar.py (1)

203-205: Drop debug print from hot path.
This unconditional print will spam stdout for every call to the fused kernel. Please remove it or guard it behind a proper debug logger.

-        print(
-            f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}"
-        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0753095 and a2670e8.

📒 Files selected for processing (6)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/mnnvl.py (18 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
  • include/flashinfer/utils.cuh (1 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (301-391)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (229-298)
  • get_allreduce_mnnvl_workspace (398-451)
  • get_required_buffer_size_bytes (116-141)
flashinfer/comm/mnnvl.py (10)
  • barrier (168-168)
  • barrier (227-228)
  • bcast (165-165)
  • bcast (224-225)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-222)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1121-1201)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1101-1118)
  • lamport_initialize (1160-1161)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (31-109)
  • trtllm_mnnvl_allreduce_fusion (31-37)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
🪛 Ruff (0.14.5)
tests/comm/test_trtllm_mnnvl_allreduce.py

279-279: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

flashinfer/comm/trtllm_mnnvl_ar.py

74-76: Avoid specifying long messages outside the exception class

(TRY003)


261-263: Avoid specifying long messages outside the exception class

(TRY003)


268-270: Avoid specifying long messages outside the exception class

(TRY003)


338-340: Avoid specifying long messages outside the exception class

(TRY003)


342-344: Avoid specifying long messages outside the exception class

(TRY003)


346-348: Avoid specifying long messages outside the exception class

(TRY003)


352-354: Avoid specifying long messages outside the exception class

(TRY003)


358-360: Avoid specifying long messages outside the exception class

(TRY003)


500-502: Avoid specifying long messages outside the exception class

(TRY003)


571-573: Avoid specifying long messages outside the exception class

(TRY003)


577-579: Avoid specifying long messages outside the exception class

(TRY003)


582-584: Avoid specifying long messages outside the exception class

(TRY003)


586-588: Avoid specifying long messages outside the exception class

(TRY003)


591-593: Avoid specifying long messages outside the exception class

(TRY003)


596-598: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


885-885: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines 334 to 389
    if epsilon is None:
        epsilon = torch.finfo(input.dtype).eps

    if len(input.shape) != 2:
        raise ValueError(
            f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}."
        )
    if len(residual_in.shape) != 2:
        raise ValueError(
            f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}."
        )
    if gamma.numel() != input.shape[1]:
        raise ValueError(
            f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements."
        )
    if output is None:
        output = torch.empty_like(input)
    elif len(output.shape) != 2:
        raise ValueError(
            f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}."
        )
    if residual_out is None:
        residual_out = torch.empty_like(residual_in)
    elif len(residual_out.shape) != 2:
        raise ValueError(
            f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}."
        )

    module = get_trtllm_mnnvl_comm_module()

    use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
        strategy == MNNVLAllreduceFusionStrategy.AUTO
        and MNNVLAllreduceFusionStrategy.is_one_shot(
            workspace.tp_size,
            input.shape[0],
            input.shape[1],
            input.dtype,
        )
    )

    module.trtllm_mnnvl_allreduce_fusion(
        input,
        workspace.mc_ptr,
        workspace.uc_ptrs_dev,
        workspace.uc_ptr_local,
        workspace.buffer_flags,
        workspace.tp_size,
        workspace.rank,
        True,  # RMSNorm Fusion
        launch_with_pdl,
        use_oneshot,
        output,
        residual_out,
        residual_in,
        gamma,
        epsilon,
Contributor

⚠️ Potential issue | 🔴 Critical

Restore RMSNorm epsilon default to 1e-5.
Overriding epsilon with torch.finfo(...).eps replaces the kernel’s built-in 1e-5 default (see trtllm_mnnvl_allreduce_fusion in csrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM. Please keep the default at 1e-5 (or leave the argument unset so the kernel default is used).

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5
🧰 Tools
🪛 Ruff (0.14.5)

338-340: Avoid specifying long messages outside the exception class

(TRY003)


342-344: Avoid specifying long messages outside the exception class

(TRY003)


346-348: Avoid specifying long messages outside the exception class

(TRY003)


352-354: Avoid specifying long messages outside the exception class

(TRY003)


358-360: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around lines 334-389, the code currently
overrides epsilon with torch.finfo(input.dtype).eps when epsilon is None which
changes RMSNorm behavior (especially for fp16); restore the original default by
either leaving epsilon as None so the CUDA kernel's built-in 1e-5 is used or
explicitly set epsilon = 1e-5 when epsilon is None, and remove the
torch.finfo(...) fallback so we do not substitute ~1e-3 for fp16.

Comment on lines 292 to 301
inline int GetCudaMultiProcessorCount() {
  static int sm_count = 0;
  if (sm_count == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, device_id);
    sm_count = device_prop.multiProcessorCount;
  }
  return sm_count;
Contributor

⚠️ Potential issue | 🟠 Major

Make GetCudaMultiProcessorCount thread‑safe and clarify multi‑device semantics

  • static int sm_count is written without synchronization; concurrent calls from multiple host threads can cause a data race and undefined behavior.
  • The function also permanently caches the SM count of whichever device is current on the first call; if the process later switches devices, the cached value will be wrong.

Consider making the cache atomic (or using std::call_once) and, if needed, keying by device ID. For example:

- inline int GetCudaMultiProcessorCount() {
-  static int sm_count = 0;
-  if (sm_count == 0) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    cudaDeviceProp device_prop;
-    cudaGetDeviceProperties(&device_prop, device_id);
-    sm_count = device_prop.multiProcessorCount;
-  }
-  return sm_count;
-}
+ inline int GetCudaMultiProcessorCount() {
+  static std::atomic<int> sm_count{0};
+  int cached = sm_count.load(std::memory_order_relaxed);
+  if (cached == 0) {
+    int device_id = 0;
+    cudaGetDevice(&device_id);
+    cudaDeviceProp device_prop{};
+    cudaGetDeviceProperties(&device_prop, device_id);
+    cached = device_prop.multiProcessorCount;
+    sm_count.store(cached, std::memory_order_relaxed);
+  }
+  return cached;
+}

(You’d also need to #include <atomic> in this header.)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
inline int GetCudaMultiProcessorCount() {
  static int sm_count = 0;
  if (sm_count == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, device_id);
    sm_count = device_prop.multiProcessorCount;
  }
  return sm_count;

inline int GetCudaMultiProcessorCount() {
  static std::atomic<int> sm_count{0};
  int cached = sm_count.load(std::memory_order_relaxed);
  if (cached == 0) {
    int device_id = 0;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop{};
    cudaGetDeviceProperties(&device_prop, device_id);
    cached = device_prop.multiProcessorCount;
    sm_count.store(cached, std::memory_order_relaxed);
  }
  return cached;
}
🤖 Prompt for AI Agents
In include/flashinfer/utils.cuh around lines 292-301, the function uses a
non-synchronized static sm_count that creates a data race and also permanently
caches the SM count for whichever device was current on first call; fix by
making the lookup thread-safe and per-device: first call cudaGetDevice() to get
the current device id, then fetch properties if no cached value exists for that
device; protect the cache with thread-safety (e.g., use std::call_once
per-device, or a mutex + an unordered_map<int,int> keyed by device id, or an
atomic<int> array sized by device count) and include the required headers
(<mutex>, <unordered_map> or <atomic>); return the cached per-device SM count
after initialization.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)

56-69: Guard RMSNorm fusion against missing residual_in and validate its shape.

The precondition at line 56 enforces presence of residual_out, gamma, and epsilon when rmsnorm_fusion=true, but omits residual_in. The kernel will dereference residualInPtr unconditionally when RMSNormFusion is true, causing undefined behavior if residual_in is absent.

Additionally, shape validation (lines 61-68) only covers residual_out and gamma; residual_in is not validated.

Extend the precondition to include residual_in:

-    TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
+    TVM_FFI_ICHECK((residual_out.has_value() && residual_in.has_value() &&
+                    gamma.has_value() && epsilon.has_value()) ||
                    !rmsnorm_fusion)
-        << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
+        << "residual_out, residual_in, gamma, and epsilon must be provided if rmsnorm_fusion is true";

Add shape validation for residual_in within the if (rmsnorm_fusion) block:

     if (rmsnorm_fusion) {
       TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                      residual_out.value().size(1) == token_dim)
           << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
           << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
           << ")";
+      TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                     residual_in.value().size(1) == token_dim)
+          << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim
+          << ") but got (" << residual_in.value().size(0) << ", "
+          << residual_in.value().size(1) << ")";
       TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
           << "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
           << gamma.value().size(0) << ")";
     }
flashinfer/comm/trtllm_mnnvl_ar.py (1)

331-332: Restore RMSNorm epsilon default to 1e-5.

Overriding epsilon with torch.finfo(input.dtype).eps replaces the kernel's built-in 1e-5 default (see line 91 in csrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM.

Apply this diff:

     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

502-504: Clarify assertion for legacy API compatibility.

The assertion at lines 502-504 will fail with a cryptic message if wait_for_results=False is passed. Since this is deprecated legacy code, the assertion is reasonable, but consider improving the error message for clarity:

-    assert wait_for_results and (out is not None), (
-        "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead."
-    )
+    if not wait_for_results or out is None:
+        raise ValueError(
+            "Legacy trtllm_mnnvl_all_reduce requires wait_for_results=True and a valid output tensor. "
+            "Please use the new trtllm_mnnvl_allreduce API instead."
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2670e8 and 92cbd48.

📒 Files selected for processing (3)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (3)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-219)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1121-1201)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1101-1118)
  • lamport_initialize (1160-1161)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (31-108)
  • trtllm_mnnvl_allreduce_fusion (31-37)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (298-388)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (226-295)
  • get_allreduce_mnnvl_workspace (395-448)
  • get_required_buffer_size_bytes (116-141)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py

74-76: Avoid specifying long messages outside the exception class

(TRY003)


258-260: Avoid specifying long messages outside the exception class

(TRY003)


265-267: Avoid specifying long messages outside the exception class

(TRY003)


335-337: Avoid specifying long messages outside the exception class

(TRY003)


339-341: Avoid specifying long messages outside the exception class

(TRY003)


343-345: Avoid specifying long messages outside the exception class

(TRY003)


349-351: Avoid specifying long messages outside the exception class

(TRY003)


355-357: Avoid specifying long messages outside the exception class

(TRY003)


497-499: Avoid specifying long messages outside the exception class

(TRY003)


568-570: Avoid specifying long messages outside the exception class

(TRY003)


574-576: Avoid specifying long messages outside the exception class

(TRY003)


579-581: Avoid specifying long messages outside the exception class

(TRY003)


583-585: Avoid specifying long messages outside the exception class

(TRY003)


588-590: Avoid specifying long messages outside the exception class

(TRY003)


593-595: Avoid specifying long messages outside the exception class

(TRY003)

tests/comm/test_trtllm_mnnvl_allreduce.py

280-280: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

226-232: Add workspace capacity check to prevent buffer overflow.

The new trtllm_mnnvl_allreduce function doesn't verify that the input tensor fits within the allocated workspace buffer. A previous review comment suggested adding a check similar to the legacy API's if inp.shape[0] > buffer_M validation.

While the author questioned whether this should be on the execution path, buffer overflow can cause crashes or silent memory corruption. Consider adding a validation check:

required_size = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
    workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
)
if required_size > workspace.buffer_size_bytes:
    raise ValueError(
        f"Input tensor requires {required_size} bytes but workspace only has "
        f"{workspace.buffer_size_bytes} bytes. Please increase workspace size."
    )

Based on past review comments, the maintainer questioned if this check should be on the execution path. If this is intentionally omitted for performance, please document this as a user responsibility in the docstring.

Comment on lines +258 to +264
    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
        )
Contributor

⚠️ Potential issue | 🟠 Major

Epsilon mismatch between test and kernel default.

Line 263 uses torch.finfo(dtype).eps for the reference RMSNorm calculation, but the kernel uses a default of 1e-5 (see csrc/trtllm_mnnvl_allreduce.cu line 91). This inconsistency will cause test failures or require overly loose tolerances.

Since flashinfer/comm/trtllm_mnnvl_ar.py also incorrectly defaults to torch.finfo(input.dtype).eps at line 332, these two issues are related. Once the main API is fixed to use 1e-5, update this test accordingly:

         residual_out = allreduce_result + residual  # Add residual
         norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+            residual_out, norm_weight, 1e-5, enable_pdl=False
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
        )

    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, 1e-5, enable_pdl=False
        )
🤖 Prompt for AI Agents
In tests/comm/test_trtllm_mnnvl_allreduce.py around lines 258 to 264, the test
uses torch.finfo(dtype).eps as the epsilon for the reference RMSNorm but the
kernel defaults to 1e-5; change the test to pass epsilon=1e-5 to rmsnorm so the
reference matches the kernel default (and after fixing
flashinfer/comm/trtllm_mnnvl_ar.py to default to 1e-5, ensure this test
continues to use that same 1e-5 constant for consistency).

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)

566-664: Close remaining POSIX FDs in IPC path to avoid leaks

In the POSIX handle path of _alloc_mn_mcast_mem, a few FDs are still never closed:

  • local_shareable_uc_handle returned by cuMemExportToShareableHandle (line 958) is used in the IPC ring allgather but never closed.
  • During the ring, each rank sends its local_shareable_uc_handle to all peers, including itself. The self‑recv for p == group_rank populates all_shareable_uc_handles[self.group_rank], but that FD is never imported (due to if p != self.group_rank) and also never closed.

You already close imported POSIX FDs after cuMemImportFromShareableHandle and close the multicast FD after import; closing the remaining two FDs will complete the cleanup and prevent per‑allocation FD leaks in long‑running jobs.

One way to fix this:

        if (
            self._shareable_handle_type
            == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
        ):
            # All-gather fabric handles
            all_shareable_uc_handles = self.comm_backend.allgather(
                local_shareable_uc_handle.data
            )
        else:
            # Implement the allgather logic with ipc socket
            all_shareable_uc_handles = [None] * self.group_size
            for i in range(self.group_size):
                self.comm_backend.barrier()
                # Send to peer at offset i
                dest_rank = (self.group_rank + i) % self.group_size
                self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
                # Receive from peer at offset -i
                src_rank = (self.group_rank + self.group_size - i) % self.group_size
                all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
+           if (
+               self._shareable_handle_type
+               == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+           ):
+               # Close our exported UC handle FD and the self-received FD
+               os.close(local_shareable_uc_handle)
+               if all_shareable_uc_handles[self.group_rank] is not None:
+                   os.close(all_shareable_uc_handles[self.group_rank])

The existing per‑peer close after import:

if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(all_shareable_uc_handles[p])

and the multicast close:

if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(shareable_mc_handle)

can stay as‑is.

Together with the new __del__ logic calling self._ipc_socket.close(), this fully addresses the descriptor‑leak concern in the IPC path.

Also applies to: 957-1005, 1008-1055

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

233-271: Align reference RMSNorm epsilon with kernel default (still using torch.finfo(dtype).eps)

prepare_test_data still uses torch.finfo(dtype).eps as the epsilon for the reference RMSNorm:

norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)

while the kernel and test harness default to eps = 1e-5 (see run_mnnvl_ar_full and the C++ FFI wrapper’s params.epsilon default). This inconsistency can mask subtle discrepancies behind loose tolerances or cause avoidable test drift.

To keep the reference path exactly aligned with the implementation, switch this to the same constant:

-        norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
-        )
+        norm_out = rmsnorm(
+            residual_out,
+            norm_weight,
+            1e-5,
+            enable_pdl=False,
+        )

(or better, reuse the same eps value passed into run_mnnvl_ar_full to avoid hard‑coding the constant twice).

🧹 Nitpick comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)

100-114: Ensure epsilon defaults stay consistent with Python API and tests

Here params.epsilon falls back to 1e-5 when the Optional epsilon is not provided:

params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;

The Python wrapper in flashinfer/comm/trtllm_mnnvl_ar.py and the tests in tests/comm/test_trtllm_mnnvl_allreduce.py should use the same default to avoid silent discrepancies between the kernel and reference paths. The core test harness already sets eps = 1e-5; the remaining mismatch is in the reference RMSNorm computation (see prepare_test_data), which still uses torch.finfo(dtype).eps.

flashinfer/comm/mnnvl.py (3)

132-150: Fix alloc_and_copy_to_cuda return type vs None behavior

alloc_and_copy_to_cuda is annotated as returning int but still returns None for an empty host_ptr_array:

def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        return None

Current call sites (signal_pads and uc_ptrs) always pass non‑empty lists, so behavior is correct, but the annotation is misleading and could hide bugs if the function gets reused.

Either make the return type explicit about None or enforce non‑emptiness by raising:

-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
-    if not host_ptr_array:
-        return None
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+    if not host_ptr_array:
+        raise ValueError("host_ptr_array must be non-empty")

(or change the annotation to Optional[int] if you prefer the sentinel behavior).


885-893: IPC opId bootstrap looks fine; consider documenting ordering guarantees

_init_ipc_socket uses an MPI‑like bcast to distribute a randomly chosen opId from rank 0, then uses it to construct IpcSocket endpoints on all ranks. This nicely avoids hard‑coding operation IDs and lines up with the C++ IPC model.

Given the reliance on collective barriers around send_fd/recv_fd, it would help future maintainers to mention in a comment here that all ranks are expected to participate in the same sequence of IPC operations for a given opId, and that mismatched usage will deadlock. The code is correct as written; this is just a documentation/clarity suggestion.


1143-1170: McastGPUBuffer workspace integration and pointer getters look consistent

The new comm_backend_for_handle_transfer parameter is threaded through to McastDeviceMemory, and the added get_unicast_ptr wrapper simply delegates to mcast_device_memory.get_unicast_ptr(rank). This lines up with how tests and get_allreduce_mnnvl_workspace use these pointers and keeps pointer access encapsulated.

The placeholder buffer‑view methods (get_multicast_buffer, get_unicast_buffer) are clearly marked NotImplementedError, so they won’t be hit accidentally. If you plan to expose tensor views later, you can implement them via create_tensor_from_cuda_memory.

Also applies to: 1209-1212

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 92cbd48 and 5be2697.

📒 Files selected for processing (4)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/mnnvl.py (18 hunks)
  • include/flashinfer/utils.cuh (2 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/utils.cuh
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (298-388)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (226-295)
  • get_allreduce_mnnvl_workspace (395-448)
  • get_required_buffer_size_bytes (116-141)
flashinfer/comm/mnnvl.py (14)
  • barrier (168-168)
  • barrier (227-228)
  • Get_rank (156-156)
  • Get_rank (215-216)
  • Get_size (159-159)
  • Get_size (218-219)
  • bcast (165-165)
  • bcast (224-225)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1205-1207)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1213-1215)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1209-1211)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-219)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


888-888: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
include/flashinfer/utils.cuh (1)

293-307: Thread‑safe SM count cache looks good; confirm single‑GPU‑per‑process assumption

Using static std::atomic<int> with relaxed loads/stores makes this helper thread‑safe and avoids the previous static int data race. The comment explicitly assumes one CUDA device per process, since the cached sm_count is never recomputed if the current device changes.

If there are any call sites that may run in a multi‑GPU‑per‑process setup, consider extending this to a per‑device cache (e.g., keyed by device id) rather than a single global integer; otherwise, this implementation is fine as long as the single‑device assumption holds.

csrc/trtllm_mnnvl_allreduce.cu (1)

41-76: RMSNorm fusion validation and shape checks look correct

The updated precondition now correctly requires residual_in, residual_out, gamma, and epsilon when rmsnorm_fusion is true, and the subsequent shape checks on residual_in, residual_out, and gamma guard the fused path against mismatched tensors. This should prevent the fused kernels from ever seeing invalid residual/norm inputs via the FFI boundary.

The overall parameter wiring into AllReduceFusionParams (including buffer pointers and flags) also looks consistent with the Python side.

flashinfer/comm/mnnvl.py (1)

781-790: Good: IPC socket is now closed in destructor

The addition of:

if hasattr(self, "_ipc_socket"):
    self._ipc_socket.close()

inside __del__ ensures the Unix domain socket is closed and, for non‑abstract sockets, the filesystem entry is unlinked. This addresses the earlier socket‑leak concern while remaining safe when construction fails before _ipc_socket is set.

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

16-103: Test harness refactor cleanly exercises both refactored and legacy APIs

The new helpers (row_linear_residual_norm_fusion_forward, _legacy, run_mnnvl_ar_full) and parametrized tests (test_mnnvl_allreduce_refactored, test_mnnvl_allreduce_legacy) do a good job of:

  • Sharing core logic between fused and non‑fused paths.
  • Covering both the new workspace‑based API and the legacy pointer‑based API.
  • Exercising a variety of sequence lengths, dtypes, and hidden sizes.
  • Integrating MPI barriers and rank‑aware logging to make multi‑rank failures diagnosable.

Once the epsilon alignment in prepare_test_data is fixed, this test suite should give solid coverage for the new fused implementation and its backward‑compatibility guarantees.

Also applies to: 274-397, 439-465

comm_backend: Optional[CommBackend] = None,
):
"""
Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
Contributor

Is there a way we can check this?

Contributor Author

@timlee0212 timlee0212 Nov 21, 2025

Forgot to update the doc. Fixed.

def __init__(
self,
mapping: Mapping,
buffer_size_in_bytes: Optional[int] = None,
Contributor

Could you provide guidance for buffer_size_in_bytes? E.g., as a function of the number of tokens and the hidden size? Or just refer to get_required_buffer_size_bytes?

Contributor Author

just refer to get_required_buffer_size_bytes

comm_backend,
)

# We use FP32 for sentinel value regardless of the real dtype
Contributor

Why? Before we used the real dtype

Contributor Author

We are using LDG.128 for all data reads/writes, and if the allocation is 4-byte (word) aligned, reading/writing each word can be considered atomic. Thus, it is sufficient to check the sentinel at FP32 granularity regardless of the dtype. This simplifies buffer management by decoupling the sentinel value from the dtype.
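
To make the point concrete, a tiny sketch of dtype-agnostic sentinel handling; the helper names and the assumption that the buffer byte size is a multiple of 4 are illustrative only:

import torch


def clear_lamport_slots(workspace: torch.Tensor, sentinel: float) -> None:
    # Reinterpret the raw workspace as FP32 words and reset them to the sentinel.
    # On 4-byte-aligned memory each word store is atomic, so a reader polling in
    # FP32 never observes a torn half of a bf16/fp16 payload element.
    workspace.view(torch.float32).fill_(sentinel)


def slots_written(workspace: torch.Tensor, sentinel: float) -> torch.Tensor:
    # Readiness check at FP32 granularity, independent of the payload dtype.
    return workspace.view(torch.float32) != sentinel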

Comment on lines +229 to +235
def trtllm_mnnvl_allreduce(
input: torch.Tensor,
workspace: MNNVLAllreduceFusionWorkspace,
launch_with_pdl: bool,
output: Optional[torch.Tensor] = None,
strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
Contributor

We do want this check. I recently added it because it did bite others.

def __init__(
self,
mapping: Mapping,
buffer_size_in_bytes: Optional[int] = None,
Contributor

Another option would be to replace "buffer_size_in_bytes" with the parameters that get_required_buffer_size_bytes takes, and just call that function from the init. That seems more user-friendly.

If you do want to just allocate a blob of memory, we could still keep buffer_size_in_bytes as an additional parameter that overrides whatever is calculated.

Contributor Author

I don't think that is a good design, as it might give the user the impression that the allocated workspace will ONLY support that particular set of parameters (max_num_tokens, hidden_dim, dtype, strategy).

In practice the workspace usage is quite flexible: as long as the required workspace size does not exceed the allocation, it will work. The intended usage is therefore that the user checks the required workspace size before using the workspace (or we could check it when calling the allreduce function, at some cost); see the sketch below.
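
To make that flow concrete, a hedged usage sketch: size the workspace once for the largest expected case, then reuse it for any call whose requirement fits. The trtllm_mnnvl_allreduce signature follows the snippet quoted earlier on this page; the get_required_buffer_size_bytes argument list and the mapping fields are assumptions and may not match the final API exactly:

import torch
from flashinfer.comm.trtllm_mnnvl_ar import (
    MNNVLAllreduceFusionStrategy,
    MNNVLAllreduceFusionWorkspace,
    trtllm_mnnvl_allreduce,
)


def allreduce_with_reused_workspace(mapping, x: torch.Tensor, max_num_tokens: int):
    # Size the buffer for the worst case this rank expects to see
    # (argument list is illustrative; defer to get_required_buffer_size_bytes).
    required = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
        mapping.tp_size,
        max_num_tokens,
        x.shape[1],
        x.dtype,
        MNNVLAllreduceFusionStrategy.AUTO,
    )
    workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes=required)

    # Later calls with fewer tokens reuse the same workspace; it only has to be
    # large enough, not sized exactly for these parameters.
    return trtllm_mnnvl_allreduce(x, workspace, launch_with_pdl=False)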

Optional<TensorView> out) {
cudaSetDevice(in.device().device_id);
auto stream = get_stream(in.device());
// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
Contributor

I don't think it is a problem

Contributor Author

Comment removed.

AUTO = 99

@staticmethod
def is_one_shot(
Contributor

Maybe rename this to "heuristic_for_one_shot" or something like that.

Contributor Author

Done.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

361-362: Critical: Restore epsilon default to 1e-5 to match kernel.

This epsilon fallback was flagged as critical in a previous review but remains unresolved. Using torch.finfo(input.dtype).eps sets epsilon to approximately 1e-3 for fp16, diverging from the kernel's built-in 1e-5 default (see csrc/trtllm_mnnvl_allreduce.cu line 96). This materially alters RMSNorm results and breaks compatibility with TensorRT-LLM.

Apply this fix:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

118-136: Consider replacing @functools.cache on instance method.

Using @functools.cache on an instance method can prevent the instance from being garbage collected, leading to memory leaks. Since this method takes self as the first parameter, the cache will hold references to the instance.

Consider either:

  1. Making this a standalone function that takes workspace parameters explicitly (see the sketch below)
  2. Using @functools.lru_cache(maxsize=...) with a reasonable limit
  3. Implementing manual caching in the instance if needed

Based on learnings
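
For alternative 1 above, a minimal sketch of a module-level cache keyed only on hashable scalars; the sufficiency formula below is a placeholder for illustration, not the kernel's real requirement:

import functools


@functools.lru_cache(maxsize=128)
def _buffer_size_is_sufficient(
    buffer_size_bytes: int,
    tp_size: int,
    num_tokens: int,
    hidden_dim: int,
    elem_size_bytes: int,
) -> bool:
    # Placeholder requirement formula; the real computation lives in
    # MNNVLAllreduceFusionWorkspace.
    required = tp_size * num_tokens * hidden_dim * elem_size_bytes
    return required <= buffer_size_bytes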

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5be2697 and c6ed147.

📒 Files selected for processing (2)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
  • Mapping (21-475)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1135-1215)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1115-1132)
  • lamport_initialize (1174-1175)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1213-1215)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1209-1211)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1205-1207)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (192-243)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py

77-79: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


282-284: Avoid specifying long messages outside the exception class

(TRY003)


289-291: Avoid specifying long messages outside the exception class

(TRY003)


303-305: Avoid specifying long messages outside the exception class

(TRY003)


365-367: Avoid specifying long messages outside the exception class

(TRY003)


369-371: Avoid specifying long messages outside the exception class

(TRY003)


373-375: Avoid specifying long messages outside the exception class

(TRY003)


379-381: Avoid specifying long messages outside the exception class

(TRY003)


385-387: Avoid specifying long messages outside the exception class

(TRY003)


398-400: Avoid specifying long messages outside the exception class

(TRY003)


528-530: Avoid specifying long messages outside the exception class

(TRY003)


599-601: Avoid specifying long messages outside the exception class

(TRY003)


605-607: Avoid specifying long messages outside the exception class

(TRY003)


610-612: Avoid specifying long messages outside the exception class

(TRY003)


614-616: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


624-626: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)

29-113: LGTM! Fusion entry point is well-structured.

The refactored entry point properly validates all inputs, including the RMSNorm fusion parameters that were flagged in previous reviews. The dispatch logic cleanly selects between oneshot and twoshot strategies, and error messages are clear and actionable.

flashinfer/comm/trtllm_mnnvl_ar.py (3)

30-48: Strategy enum and heuristic look good.

The MNNVLAllreduceFusionStrategy enum provides a clear interface for selecting between oneshot and twoshot approaches, with a sensible AUTO mode that uses an empirically-derived threshold.
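
A rough sketch of how an AUTO mode can resolve to one-shot or two-shot at dispatch time; only AUTO = 99 is taken from the diff quoted earlier, while the other enum values and the token threshold are placeholders, not the kernel's actual heuristic:

import enum


class Strategy(enum.IntEnum):
    ONESHOT = 0  # placeholder value
    TWOSHOT = 1  # placeholder value
    AUTO = 99


def resolve_use_oneshot(strategy: Strategy, num_tokens: int, tp_size: int) -> bool:
    if strategy == Strategy.AUTO:
        # Small messages favor one-shot (each rank broadcasts its shard once);
        # larger ones favor two-shot (reduce-scatter followed by allgather).
        return num_tokens * tp_size <= 128  # threshold is a placeholder
    return strategy == Strategy.ONESHOT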


250-326: Buffer size validation properly implemented.

The function now includes the buffer size check that was requested in previous reviews (lines 300-305), preventing potential out-of-bounds access. Input validation is comprehensive and error messages are clear.


422-646: Deprecation strategy is well-executed.

The legacy APIs are properly marked with @deprecated decorators and include clear migration guidance. The wrappers correctly redirect to the new fusion-based implementations, maintaining backward compatibility while encouraging adoption of the improved APIs.
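
A generic sketch of that pattern, assuming nothing about flashinfer's actual deprecation helper: the legacy name emits a DeprecationWarning and forwards to the new API. The legacy function name below is hypothetical.

import functools
import warnings


def deprecated(message: str):
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return fn(*args, **kwargs)

        return inner

    return wrap


@deprecated("Use trtllm_mnnvl_allreduce with an MNNVLAllreduceFusionWorkspace instead.")
def legacy_all_reduce(*args, **kwargs):
    ...  # in the real module this forwards to the new fusion-based entry point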

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)

132-149: Fix return type inconsistency.

The function returns None at line 137 when host_ptr_array is empty, but the return type annotation at line 132 indicates int. This creates a type mismatch.

Consider one of these fixes:

Option 1: Return Optional[int] and update callers to handle None:

-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:

Option 2: Raise an error instead of returning None:

     if not host_ptr_array:
-        return None
+        raise ValueError("host_ptr_array cannot be empty")
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

377-378: Restore RMSNorm epsilon default to 1e-5.

Overriding epsilon with torch.finfo(input.dtype).eps replaces the kernel's built-in 1e-5 default (see trtllm_mnnvl_allreduce_fusion in csrc/trtllm_mnnvl_allreduce.cu line ~35: params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking numerical parity.

Apply this diff to fix:

     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

134-152: Consider alternatives to @functools.cache on instance methods.

Using @functools.cache (or @lru_cache) on instance methods can prevent garbage collection of instances because the cache holds references to bound methods, which in turn hold references to self. Since MNNVLAllreduceFusionWorkspace instances are likely long-lived in typical usage, this may be acceptable, but consider these alternatives:

  • Use @functools.lru_cache(maxsize=128) to limit cache growth
  • Move caching logic to a module-level cache keyed on relevant parameters
  • Document the caching behavior and its memory implications

Based on learnings

Apply this diff if you want to limit cache size:

-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def is_buffer_size_sufficient(
flashinfer/comm/mnnvl.py (2)

640-654: Prefix unused unpacked variables with underscore.

The variables msg, flags, and addr from recvmsg are unpacked but never used. Prefix them with _ to indicate they're intentionally ignored.

Apply this diff:

-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(

893-900: Consider using secrets module for opId generation.

While cryptographic randomness is not strictly required for socket naming, using secrets.randbelow(2**64) instead of random.randint provides better collision resistance if multiple jobs run concurrently on the same node.

Apply this diff:

+import secrets
+
     def _init_ipc_socket(self):
         if self.group_rank == 0:
-            # Gnerate the opId
-            opId = random.randint(0, 2**64 - 1)
+            # Generate the opId
+            opId = secrets.randbelow(2**64)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6ed147 and a390685.

📒 Files selected for processing (2)
  • flashinfer/comm/mnnvl.py (19 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1143-1224)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1123-1140)
  • lamport_initialize (1183-1184)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1222-1224)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1218-1220)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1214-1216)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


896-896: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

flashinfer/comm/trtllm_mnnvl_ar.py

77-79: Avoid specifying long messages outside the exception class

(TRY003)


134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


298-300: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


319-321: Avoid specifying long messages outside the exception class

(TRY003)


381-383: Avoid specifying long messages outside the exception class

(TRY003)


385-387: Avoid specifying long messages outside the exception class

(TRY003)


389-391: Avoid specifying long messages outside the exception class

(TRY003)


395-397: Avoid specifying long messages outside the exception class

(TRY003)


401-403: Avoid specifying long messages outside the exception class

(TRY003)


414-416: Avoid specifying long messages outside the exception class

(TRY003)


544-546: Avoid specifying long messages outside the exception class

(TRY003)


615-617: Avoid specifying long messages outside the exception class

(TRY003)


621-623: Avoid specifying long messages outside the exception class

(TRY003)


626-628: Avoid specifying long messages outside the exception class

(TRY003)


630-632: Avoid specifying long messages outside the exception class

(TRY003)


635-637: Avoid specifying long messages outside the exception class

(TRY003)


640-642: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

266-341: LGTM!

The function correctly validates inputs, selects strategy, checks buffer size sufficiency (addressing past review feedback), and invokes the fusion kernel with appropriate parameters.

flashinfer/comm/mnnvl.py (1)

788-789: LGTM!

The IPC socket cleanup correctly uses hasattr to check for existence before closing, addressing the file descriptor leak concern from past reviews.
