diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 6bac5372a8..e1c998d8ea 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,77 +26,90 @@ using tvm::ffi::Optional; } \ }() -void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, - int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, - int64_t rank, bool wait_for_results, bool launch_with_pdl, - Optional out) { - cudaSetDevice(in.device().device_id); - auto stream = get_stream(in.device()); +void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, + int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, + TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, + bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, + TensorView output, Optional residual_out, + Optional residual_in, Optional gamma, + Optional epsilon) { + cudaSetDevice(input.device().device_id); + auto stream = get_stream(input.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] { // Extract parameters from tensors - int64_t num_tokens = in.size(0); - int64_t token_dim = in.size(1); + int64_t num_tokens = input.size(0); + int64_t token_dim = input.size(1); // Validate input parameters - TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0) - << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type); + TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0) + << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type); + TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1)) + << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << output.size(0) << ", " << output.size(1) << ")"; TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64) << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK(out.has_value() || !wait_for_results) - << "out tensor must be provided if wait_for_results is true"; + TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() && + epsilon.has_value()) || + !rmsnorm_fusion) + << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is " + "true"; + + if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1) + << ")"; + TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && + residual_out.value().size(1) == token_dim) + << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) + << ")"; + TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) + << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" + << gamma.value().size(0) << ")"; + } // Create the parameters struct - AllReduceParams params; - params.nranks = nranks; - params.rank = rank; - params.buffer_M = buffer_M; - params.num_tokens = num_tokens; - params.token_dim = token_dim; - params.buffer_ptrs_dev = reinterpret_cast(buffer_ptrs_dev); - params.multicast_ptr = reinterpret_cast(multicast_buffer_ptr); - params.buffer_flags = buffer_flags_mnnvl.data_ptr(); - params.wait_for_results = wait_for_results; - params.launch_with_pdl = launch_with_pdl; - params.input = in.data_ptr(); - params.output = out.has_value() ? out.value().data_ptr() : nullptr; - params.stream = stream; + AllReduceFusionParams params; - auto status = twoshot_allreduce_dispatch_world_size(params); - TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); - }); -} + // Aux Information + params.nRanks = nranks; + params.rank = rank; + params.numTokens = num_tokens; + params.tokenDim = token_dim; + params.bufferPtrsDev = reinterpret_cast(buffer_ptrs_dev); + params.bufferPtrLocal = reinterpret_cast(buffer_ptr_local); + params.multicastPtr = reinterpret_cast(multicast_buffer_ptr); + params.bufferFlags = reinterpret_cast(buffer_flags_mnnvl.data_ptr()); + params.rmsNormFusion = rmsnorm_fusion; + params.launchWithPdl = launch_with_pdl; -void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, - TensorView normed_output, TensorView gamma, double epsilon, - TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { - cudaSetDevice(prenorm_output.device().device_id); - auto stream = get_stream(prenorm_output.device()); + // input data + params.input = const_cast(input.data_ptr()); + params.residualIn = + residual_in.has_value() ? const_cast(residual_in.value().data_ptr()) : nullptr; + params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; + params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] { - // Create the parameters struct - RMSNormParams params; - params.residual_output = prenorm_output.data_ptr(); - params.output = normed_output.data_ptr(); - params.input = reinterpret_cast(multicast_buffer_ptr); - params.gamma = gamma.data_ptr(); - params.epsilon = epsilon; - params.residual = residual.data_ptr(); - params.buffer_flags = reinterpret_cast(buffer_flags.data_ptr()); - params.batch = normed_output.size(0); - params.hidden_dim = normed_output.size(1); + // output data + params.output = const_cast(output.data_ptr()); + params.residualOut = + residual_out.has_value() ? const_cast(residual_out.value().data_ptr()) : nullptr; params.stream = stream; - params.launch_with_pdl = launch_with_pdl; - auto status = twoshot_rmsnorm_dispatch_hidden_dim(params); + + cudaError_t status; + if (use_oneshot) { + status = oneshotAllreduceFusionDispatch(params); + } else { + status = twoshotAllreduceFusionDispatch(params); + } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_rmsnorm_dispatch_hidden_dim failed with error code " - << cudaGetErrorString(status); + << "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status); }); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion); diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 12aec978ec..48c04e6287 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -16,6 +16,12 @@ import ctypes import logging import os +import socket +import array +import random + +import contextlib + from abc import ABC, abstractmethod from dataclasses import dataclass import platform @@ -123,7 +129,7 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: return False -def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: """ A helper function that allocates memory on cuda and copies the data from the host to the device. """ @@ -140,7 +146,7 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: ) # c_array should be freed by GC - return device_ptr + return int(device_ptr) class CommBackend(ABC): @@ -155,6 +161,12 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def bcast(self, data: Any, root: int) -> Any: ... + + @abstractmethod + def barrier(self) -> None: ... + @abstractmethod def Split(self, color: int, key: int) -> "CommBackend": ... @@ -209,6 +221,12 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def bcast(self, data: Any, root: int) -> Any: + return self._mpicomm.bcast(data, root) + + def barrier(self): + self._mpicomm.Barrier() + def Split(self, color: int, key: int) -> CommBackend: self._mpicomm = self._mpicomm.Split(color, key) return MPIBackend() # Returns new adapter @@ -545,6 +563,107 @@ def supports_mnnvl() -> bool: return support_nvlink_and_all_up +# The helper class for passing the FD handle over the socket. +class IpcSocket: + """Unix Domain Socket for IPC file descriptor passing""" + + def __init__(self, rank: int, op_id: int, use_abstract=True): + """ + Initialize IPC socket + + Args: + rank: Process rank + op_id: Unique operation ID (hash) + use_abstract: Use Linux abstract socket namespace + """ + self.rank = rank + self.op_id = op_id + self.use_abstract = use_abstract + + # Create Unix domain socket (DGRAM for compatibility with C code) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + + # Create unique socket name + socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}" + + if use_abstract: + # Linux abstract socket: prepend null byte + self.socket_path = "\0" + socket_name + else: + self.socket_path = socket_name + # Remove existing socket file if it exists + with contextlib.suppress(FileNotFoundError): + os.unlink(socket_name) + + # Bind socket + self.sock.bind(self.socket_path) + + def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None): + """ + Send a file descriptor to another process + + Args: + fd: File descriptor to send + dest_rank: Destination process rank + dest_op_id: Destination operation ID + """ + # Construct destination socket path + dest_op_id = dest_op_id or self.op_id + dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}" + + if self.use_abstract: + dest_path = "\0" + dest_socket_name + else: + dest_path = dest_socket_name + + # Prepare message with file descriptor + # Send dummy byte as data (required) + dummy_data = b"\x00" + + # Pack file descriptor in ancillary data (SCM_RIGHTS) + fds = array.array("i", [fd]) + ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())] + + # Send message with file descriptor + self.sock.sendmsg([dummy_data], ancillary, 0, dest_path) + + def recv_fd(self): + """ + Receive a file descriptor from another process + + Returns: + int: Received file descriptor + """ + # Receive message with ancillary data + # Maximum size for ancillary data containing one fd + fds = array.array("i") + msg, ancdata, flags, addr = self.sock.recvmsg( + 1, + socket.CMSG_SPACE( + fds.itemsize + ), # Buffer size for dummy data # Ancillary data size + ) + + # Extract file descriptor from ancillary data + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds = array.array("i") + fds.frombytes( + cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)] + ) + return fds[0] + + raise RuntimeError("No file descriptor received") + + def close(self): + """Close the socket""" + self.sock.close() + if not self.use_abstract and self.socket_path: + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + + +# TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -555,6 +674,7 @@ def __init__( group_rank: int, device_idx: int, is_multi_node: bool = True, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -581,6 +701,7 @@ def __init__( self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 + self.comm_backend = comm_backend_for_handle_transfer or MPIBackend() # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr @@ -593,6 +714,8 @@ def __init__( int ] = [] # std::vector mUcHandles + self._shareable_handle_type = None + # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE @@ -630,11 +753,17 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." ) - - self._alloc_mn_mcast_mem(buf_size) + # Use fabric handle for multi-node NVLS + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) else: - # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem - raise NotImplementedError("Single-node NVLS allocation not implemented yet") + self._init_ipc_socket() + # Use NVLink handle for single-node NVLS + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ) + self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads self.signal_pads = [0] * self.group_size @@ -656,8 +785,8 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if not self.is_multi_node: - return + if hasattr(self, "_ipc_socket"): + self._ipc_socket.close() # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. @@ -753,6 +882,23 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size + def get_allocation_size(self) -> int: + """Get the total allocation size (including signal pad)""" + return self.allocation_size + + def get_usable_buffer_size(self) -> int: + """Get the usable buffer size (excluding signal pad)""" + return self.allocation_size - self.SIGNAL_PAD_SIZE + + def _init_ipc_socket(self): + if self.group_rank == 0: + # Gnerate the opId + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm_backend.bcast(opId, root=0) + self._ipc_socket = IpcSocket(self.group_rank, opId) + def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" @@ -767,14 +913,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): except Exception as e: print(f"Error checking CUDA context: {e}") - # Get MPI communicator - comm = MpiComm() - # Set up allocation properties - handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = handle_type + allocation_prop.requestedHandleTypes = self._shareable_handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() allocation_prop.location.type = ( @@ -788,7 +929,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, - cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, + cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED, ) ) @@ -801,7 +942,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = handle_type + mc_prop.handleTypes = self._shareable_handle_type # Get multicast granularity mc_granularity = checkCudaErrors( @@ -821,17 +962,34 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) - # Export local handle to fabric handle - my_fabric_handle = checkCudaErrors( + # Export local handle to fabric handle or FD + local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) - # All-gather fabric handles - all_fabric_handles = comm.allgather(my_fabric_handle.data) + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): + # All-gather fabric handles + all_shareable_uc_handles = self.comm_backend.allgather( + local_shareable_uc_handle.data + ) + else: + # Implement the allgather logic with ipc socket + all_shareable_uc_handles = [None] * self.group_size + for i in range(self.group_size): + self.comm_backend.barrier() + # Send to peer at offset i + dest_rank = (self.group_rank + i) % self.group_size + self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) + # Receive from peer at offset -i + src_rank = (self.group_rank + self.group_size - i) % self.group_size + all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() cuda.cuCtxSynchronize() # Import remote handles @@ -839,42 +997,70 @@ def _alloc_mn_mcast_mem(self, buf_size: int): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - all_fabric_handles[p], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + all_shareable_uc_handles[p], + self._shareable_handle_type, ) ) + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(all_shareable_uc_handles[p]) # Initialize multicasting if self.group_rank == 0: # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - # Export multicast handle - mc_fabric_handle = checkCudaErrors( + # Export multicast handle, there's only one handle for the entire group + shareable_mc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) else: - mc_fabric_handle = None - - # Broadcast multicast handle - mc_fabric_handle_data = comm.bcast( - mc_fabric_handle.data if mc_fabric_handle else None, root=0 - ) + shareable_mc_handle = None + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): + # Broadcast multicast handle + shareable_mc_handle = self.comm_backend.bcast( + shareable_mc_handle.data if shareable_mc_handle else None, root=0 + ) + else: + # Implement bcast logic with ipc socket + if self.group_rank == 0: + for p in range(1, self.group_size): + self.comm_backend.barrier() + self._ipc_socket.send_fd(shareable_mc_handle, p) + else: + # Other ranks receive from rank 0 + # We need to order the receive to avoid a race condition bug we encountered. If driver fixed this issue, the additional barriers used for ordering can be removed. + for _ in range(self.group_rank): + self.comm_backend.barrier() + shareable_mc_handle = self._ipc_socket.recv_fd() + for _ in range(self.group_size - self.group_rank - 1): + self.comm_backend.barrier() # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - mc_fabric_handle_data, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + shareable_mc_handle, + self._shareable_handle_type, ) ) - + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(shareable_mc_handle) # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) @@ -946,8 +1132,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): else: raise ValueError(f"Unsupported dtype: {dtype}") - # Calculate number of elements that fit in allocation_size - num_elements = self.allocation_size // dsize + # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. + num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize checkCudaErrors( memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) @@ -969,27 +1155,35 @@ def __init__( group_rank: int, device: torch.device, mn_nvlink: bool = True, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): """ Constructor for McastGpuBuffer. Args: - buf_size: The total size of the buffer in bytes + buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements. group_size: The number of ranks in the communication group group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used + comm_backend_for_handle_transfer: The communicator to use for handle transfer """ self.mcast_device_memory = McastDeviceMemory( - buf_size, group_size, group_rank, device.index, mn_nvlink + buf_size, + group_size, + group_rank, + device.index, + mn_nvlink, + comm_backend_for_handle_transfer, ) - self.buf_size = buf_size + # Update buf_size to reflect the actual usable buffer size after allocation + self.buf_size = self.mcast_device_memory.get_usable_buffer_size() self.local_device = device def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_mc_buffer( + def get_multicast_buffer( self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 ) -> torch.Tensor: """ @@ -1003,12 +1197,28 @@ def get_mc_buffer( Returns: A PyTorch tensor wrapping the multicast buffer section """ + + # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. + raise NotImplementedError("Not implemented yet") + + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: + """ + Returns a PyTorch tensor view of the unicast buffer portion. + """ + + # TODO: How can I warp a raw pointer to a tensor in python level? raise NotImplementedError("Not implemented yet") def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() + def get_unicast_ptr(self, rank: int) -> int: + """Get the raw unicast pointer to a given rank""" + return self.mcast_device_memory.get_unicast_ptr(rank) + def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 76aedee260..afdd580910 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -5,17 +5,19 @@ import functools import math -import os +import logging from types import SimpleNamespace from typing import Optional, Tuple +from enum import Enum import torch +from typing_extensions import deprecated from flashinfer.comm.mapping import Mapping from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer +from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend def mpi_barrier(): @@ -25,104 +27,423 @@ def mpi_barrier(): MPI.COMM_WORLD.Barrier() +class MNNVLAllreduceFusionStrategy(Enum): + ONESHOT = 0 + TWOSHOT = 1 + AUTO = 99 + + @staticmethod + def select_strategy( + tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype + ) -> "MNNVLAllreduceFusionStrategy": + elem_size = torch.tensor([], dtype=dtype).element_size() + if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: + return MNNVLAllreduceFusionStrategy.ONESHOT + else: + return MNNVLAllreduceFusionStrategy.TWOSHOT + + +# Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size +MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 + + +class MNNVLAllreduceFusionWorkspace: + NUM_LAMPORT_BUFFERS = 3 + + def __init__( + self, + mapping: Mapping, + buffer_size_in_bytes: Optional[int] = None, + comm_backend: Optional[CommBackend] = None, + ): + """ + Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. + + Args: + mapping: Mapping configuration containing rank info + buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. + """ + if buffer_size_in_bytes is None: + # Default to 16MB workspace size if not provided + buffer_size_in_bytes = 16 * (1024**2) + else: + # Round up to the nearest multiple of 8MB + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * ( + 8 * (1024**2) + ) + if comm_backend is None: + comm_backend = MPIBackend() + if buffer_size_in_bytes > (2**32 - 1): + raise ValueError( + f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." + ) + + # Calculate total requested workspace size + requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + + self.rank = mapping.tp_rank + self.tp_size = mapping.tp_size + logging.debug( + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer." + ) + + # Allocate the workspace + self.mcast_buffer_handle = McastGPUBuffer( + requested_workspace_size, + mapping.tp_size, + mapping.tp_rank, + torch.device("cuda", mapping.local_rank), + mapping.is_multi_node(), + comm_backend, + ) + + # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer) + allocated_size = self.mcast_buffer_handle.buf_size + # We want the buffer size to be aligned to 16B which is the granularity for buffer management. + self.buffer_size_bytes = ( + math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 + ) + # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size. + self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS + + logging.debug( + f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes." + ) + + # We use FP32 for sentinel value regardless of the real dtype + self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) + # Wait until the initialization is done + torch.cuda.synchronize() + comm_backend.barrier() + + # This is a buffer to maintain the state of this allreduce Op + # Should have the same lifetime with self._buffer + # The flag should be binded to each buffer allocation + # Layout: [cur idx, dirty idx, bytes per buffer, dirty num stages, numBytesToClear[4], access count ptr] + num_bytes_to_clear = [0] * 4 + self.buffer_flags = torch.tensor( + [0, 2, self.buffer_size_bytes, 0, *num_bytes_to_clear, 0], + dtype=torch.uint32, + device=torch.device("cuda", mapping.local_rank), + ) + + self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() + self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) + self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + + @functools.cache + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> bool: + """ + Calculate the required buffer size for a given problem size. + """ + required_buffer_size = self.get_required_buffer_size_bytes( + tp_size, num_tokens, hidden_dim, dtype, strategy + ) + if required_buffer_size > self.buffer_size_bytes: + return False + else: + return True + + @staticmethod + @functools.cache + def get_required_buffer_size_bytes( + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> int: + """ + Calculate the required buffer size for a given problem size. + """ + elem_size = torch.tensor([], dtype=dtype).element_size() + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + tp_size, num_tokens, hidden_dim, dtype + ) + + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: + # For one-shot, each rank needs to store num_tokens * tp_size tokens + buffer_size = num_tokens * hidden_dim * tp_size * elem_size + else: + # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. + # 2 Stage is required for the two-shot allreduce. + buffer_size = ( + 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + ) + return buffer_size + + @functools.cache def get_trtllm_mnnvl_comm_module(): module = gen_trtllm_mnnvl_comm_module().build_and_load() @register_custom_op( - "flashinfer::trtllm_mnnvl_all_reduce", + "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ - "inp", + "input", "multicast_buffer_ptr", "buffer_ptrs_dev", - "buffer_mnnvl", + "buffer_ptr_local", "buffer_flags_mnnvl", "nranks", "rank", - "wait_for_results", + "rmsnorm_fusion", "launch_with_pdl", - "out", + "use_oneshot", + "output", + "residual_out", + "residual_in", + "gamma", + "epsilon", ], ) - def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, + def trtllm_mnnvl_allreduce_fusion( + input: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer - buffer_mnnvl: torch.Tensor, + buffer_ptr_local: int, # Pointer address as integer buffer_flags_mnnvl: torch.Tensor, nranks: int, rank: int, - wait_for_results: bool, + rmsnorm_fusion: bool, launch_with_pdl: bool, - out: Optional[torch.Tensor], + use_oneshot: bool, + output: torch.Tensor, + residual_out: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + epsilon: Optional[float], ) -> None: - module.trtllm_mnnvl_all_reduce( - inp, + """ + Perform a multi-node NVLink all-reduce operation with fusion. + Args: + input: Input tensor + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer + buffer_ptr_local: Pointer to local buffer as an integer + buffer_flags_mnnvl: Buffer flags tensor for synchronization + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + rmsnorm_fusion: Whether to perform RMSNorm fusion + launch_with_pdl: Whether to launch with PDL + use_oneshot: Whether to use one-shot (true) or two-shot (false) + output: Output tensor + residual_out: Residual output tensor (if rmsnorm) + gamma: Gamma tensor (if rmsnorm) + epsilon: Epsilon value (if rmsnorm) + """ + module.trtllm_mnnvl_allreduce_fusion( + input, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_mnnvl, + buffer_ptr_local, buffer_flags_mnnvl, nranks, rank, - wait_for_results, + rmsnorm_fusion, launch_with_pdl, - out, + use_oneshot, + output, + residual_out, + residual_in, + gamma, + epsilon, ) - @register_custom_op( - "flashinfer::trtllm_mnnvl_rmsnorm", - mutates_args=[ - "mcast_buffer_input", - "prenorm_output", - "normed_output", - "gamma", - "epsilon", - "residual", - "buffer_flags", - "launch_with_pdl", - ], + return SimpleNamespace( + trtllm_mnnvl_allreduce_fusion=trtllm_mnnvl_allreduce_fusion, ) - def trtllm_mnnvl_rmsnorm( - mcast_buffer_input: int, - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - buffer_flags: torch.Tensor, - launch_with_pdl: bool, - ) -> None: - """Performs MNNVL TwoShot RMSNorm on the communication buffer. - Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - mcast_buffer_input: Input tensor - gamma: The gamma parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add - buffer_flags: Buffer flags for synchronization - launch_with_pdl: Whether to launch with PDL - """ - return module.trtllm_mnnvl_rmsnorm( - mcast_buffer_input, - prenorm_output, - normed_output, - gamma, - epsilon, - residual, - buffer_flags, - launch_with_pdl, + +def trtllm_mnnvl_allreduce( + input: torch.Tensor, + workspace: MNNVLAllreduceFusionWorkspace, + launch_with_pdl: bool, + output: Optional[torch.Tensor] = None, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, +) -> torch.Tensor: + """Perform a multi-node NVLink all-reduce operation across multiple GPUs. + + This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) + technology to efficiently combine tensors across multiple GPUs and nodes. + + There are 2 variants: One-shot and Two-shot: + - One-shot: Each rank stores local shard to all other ranks. Each ranks will receive all shards at the end of the communication round and perfom local reduction. Suitable for small data size and is optimized for low latency. + - Two-shot: There will be 3 steps: + 1. Scatter each GPU's input shard to other ranks. Each rank will received all shards of a slice of tokens. + 2. Each rank perform reduction on the local tokens. + 3. Each rank broadcast the result to all ranks. + Suitable for large data size and is optimized for balancing throughput and latency. + + Args: + input: Local Input Shard [num_tokens, hidden_dim] + workspace: MNNVLAllreduceFusionWorkspace + launch_with_pdl: Whether to launch with PDL + output: Output tensor to store the result, empty tensor will be created if not provided. + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Reduced tensor [num_tokens, hidden_dim] + """ + + # Check ndims here as the shape check is done in the kernel launch code. + if len(input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." ) - return SimpleNamespace( - trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, - trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm, + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + + module.trtllm_mnnvl_allreduce_fusion( + input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + False, # No RMSNorm Fusion + launch_with_pdl, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, + output, + None, + None, + None, + None, + ) + + return output + + +def trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input: torch.Tensor, + residual_in: torch.Tensor, + gamma: torch.Tensor, + workspace: MNNVLAllreduceFusionWorkspace, + epsilon: Optional[float] = None, + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + launch_with_pdl: bool = False, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Performs MNNVL Allreduce + Residual + RMSNorm. + + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. + After this, it performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + Note: multicast buffer is the same as the unicast buffer for the current rank. + + Args: + input: Input tensor [num_tokens, hidden_dim] + residual_in: Residual input tensor [num_tokens, hidden_dim] + gamma: Gamma tensor [hidden_dim] + workspace: MNNVLAllreduceFusionWorkspace + epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. + output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. + residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. + launch_with_pdl: Whether to launch with PDL + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + + Returns: + output: Add-residual and normalized tensor [num_tokens, hidden_dim] + residual_out: Add-residual tensor [num_tokens, hidden_dim] + """ + + if epsilon is None: + epsilon = torch.finfo(input.dtype).eps + + if len(input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) + if len(residual_in.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." + ) + if gamma.numel() != input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements." + ) + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) + if residual_out is None: + residual_out = torch.empty_like(residual_in) + elif len(residual_out.shape) != 2: + raise ValueError( + f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + + module.trtllm_mnnvl_allreduce_fusion( + input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + True, # RMSNorm Fusion + launch_with_pdl, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, + output, + residual_out, + residual_in, + gamma, + epsilon, ) + return output, residual_out +# Legacy API that has been deprecated; Left for backward compatibility +@deprecated( + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" +) def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None + mapping: Mapping, + dtype: torch.dtype, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, + buffer_size_in_bytes: Optional[int] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -146,8 +467,6 @@ def get_allreduce_mnnvl_workspace( - torch.Tensor: Buffer flags tensor tracking state - int: Maximum number of elements that can fit in buffer """ - force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 @@ -159,31 +478,16 @@ def get_allreduce_mnnvl_workspace( buffer_size_in_bytes = math.ceil( TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) ) * (lcm_hidden_dim * stride) - max_num_elements = buffer_size_in_bytes // stride - - mcast_buffer = McastGPUBuffer( - buffer_size_in_bytes, - mapping.tp_size, - mapping.tp_rank, - torch.device("cuda", mapping.local_rank), - mapping.is_multi_node() or force_mn, - ) - - # Initialize the unicast buffer with -0.0 - mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) - # CPU barrier since we assume this should not be called in cuda graph - torch.cuda.synchronize() - mpi_barrier() - - # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] - buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", mapping.local_rank), + # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. + workspace = MNNVLAllreduceFusionWorkspace( + mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer ) + mcast_buffer = workspace.mcast_buffer_handle + buffer_flags = workspace.buffer_flags + max_num_elements = workspace.buffer_size_bytes // stride + return ( mcast_buffer, buffer_flags, @@ -191,6 +495,9 @@ def get_allreduce_mnnvl_workspace( ) +@deprecated( + "trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future." +) def trtllm_mnnvl_all_reduce( inp: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer @@ -232,26 +539,39 @@ def trtllm_mnnvl_all_reduce( f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." ) + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: raise ValueError( f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." ) + # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) module = get_trtllm_mnnvl_comm_module() - module.trtllm_mnnvl_all_reduce( + module.trtllm_mnnvl_allreduce_fusion( inp, multicast_buffer_ptr, - int(buffer_ptrs_dev), - buffer_M, + buffer_ptrs_dev, + 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. buffer_flags_mnnvl, nranks, rank, - wait_for_results, + False, # No RMSNorm Fusion launch_with_pdl, + False, # Use two-shot out, + None, + None, + None, + None, ) +@deprecated( + "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future." +) def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: torch.Tensor, normed_output: torch.Tensor, @@ -291,30 +611,52 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( launch_with_pdl: Whether to launch with PDL """ - # allreduce_result = Σ(shard_input across all ranks) - trtllm_mnnvl_all_reduce( + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if shard_input.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + if len(residual.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}." + ) + if gamma.numel() != shard_input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements." + ) + + if len(normed_output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}." + ) + + if len(prenorm_output.shape) != 2: + raise ValueError( + f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + module.trtllm_mnnvl_allreduce_fusion( shard_input, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_M, + unicast_ptr, buffer_flags_mnnvl, nranks, rank, - False, # No need to wait to write AR results here as we are not writing them + True, # RMSNorm Fusion launch_with_pdl, - None, # out parameter - None since wait_for_results=False - ) - - # prenorm_output = AllReduce(shard_input) + residual - # rms = sqrt(mean(prenorm_output²) + epsilon) - # normed_output = (prenorm_output / rms) * gamma - get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm( - unicast_ptr, - prenorm_output, + False, normed_output, + prenorm_output, + residual, gamma, epsilon, - residual, - buffer_flags_mnnvl, - launch_with_pdl, ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 3dbed4b649..2177cfc618 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -18,52 +18,54 @@ #include #include #include +#include #include +#include #include "../exception.h" #include "../logging.h" +#include "../utils.cuh" namespace flashinfer { namespace trtllm_mnnvl_allreduce { -template -struct AllReduceParams { - int nranks; +struct AllReduceFusionParams { + int nRanks; int rank; - int buffer_M; - int num_tokens; - int token_dim; - void** buffer_ptrs_dev; - void* multicast_ptr; - void* buffer_flags; - bool wait_for_results; - bool launch_with_pdl; - - void* input; - void* output; - cudaStream_t stream; -}; + int numTokens; + int tokenDim; + void** bufferPtrsDev; + void* bufferPtrLocal; + void* multicastPtr; + uint32_t* bufferFlags; + bool rmsNormFusion; + bool launchWithPdl; -template -struct RMSNormParams { - void* residual_output; - void* output; void const* input; + void const* residualIn; void const* gamma; double epsilon; - void* residual; - uint32_t* buffer_flags; - int batch; - int hidden_dim; - cudaStream_t stream; - bool launch_with_pdl; + + void* residualOut; + void* output; + cudaStream_t stream = nullptr; }; -__device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } +namespace utils { + +constexpr uint16_t kNEGZERO_FP16 = 0x8000U; + +template +union Fp16BitCast { + T mFp; + uint16_t mInt; + + constexpr Fp16BitCast() : mInt(0) {} -__device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } + constexpr Fp16BitCast(T val) : mFp(val) {} -__device__ bool isNegZero(__nv_half val) { return isNegZero(__half2float(val)); } + constexpr Fp16BitCast(uint16_t val) : mInt(val) {} +}; template inline __device__ float toFloat(T val) { @@ -74,7 +76,6 @@ template <> inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); } - template <> inline __device__ float toFloat<__nv_half>(__nv_half val) { return __half2float(val); @@ -95,581 +96,1126 @@ inline __device__ __nv_half fromFloat<__nv_half>(float val) { return __float2half(val); } -inline __device__ float2 loadfloat2(void const* ptr) { - float2 return_value; - asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" - : "=f"(return_value.x), "=f"(return_value.y) - : "l"(ptr)); - return return_value; +template +static constexpr __device__ __host__ T negZero() { + if constexpr (std::is_same_v) { + return -0.0F; + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(kNEGZERO_FP16).mFp; + } else { + static_assert(sizeof(T) == 0, "negativeZero not specialized for this type"); + } + return T{}; // Never reached, but needed for compilation } template -inline __device__ T divUp(T val, T divisor) { - return (val + divisor - 1) / divisor; +static inline __device__ bool isNegZero(T val) { + if constexpr (std::is_same_v) { + return val == 0.F && signbit(val); + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(val).mInt == kNEGZERO_FP16; + } else { + static_assert(sizeof(T) == 0, "isNegZero not specialized for this type"); + } + return false; // Never reached, but needed for compilation } -__device__ struct __attribute__((aligned(32))) LamportFlags { - uint32_t buffer_size; - uint32_t input_offset; - uint32_t clear_offset; - uint32_t num_tokens_prev; - uint32_t* offset_access_ptr; - uint32_t* buffer_flags; - - __device__ explicit LamportFlags(uint32_t* buffer_flags) - : offset_access_ptr(&buffer_flags[4]), buffer_flags(buffer_flags) { - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; - input_offset = flag.x * (buffer_size << 1U); - clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; - } - - __device__ void cta_arrive() { - __syncthreads(); - if (threadIdx.x == 0) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) - : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#else - atomicAdd(offset_access_ptr, 1); -#endif - } - } +template +constexpr __device__ __host__ PackedType getPackedLamportInit() { + static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size"); + constexpr int kNumElements = sizeof(PackedType) / sizeof(T); - __device__ void wait_and_update(uint32_t num_tokens) { - if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) { - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + union PackedT { + PackedType mPacked; + std::array mElements; + + constexpr PackedT() : mElements{} { + for (int i = 0; i < kNumElements; i++) { + mElements[i] = negZero(); } - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; - *(offset_access_ptr) = 0; } + }; + + PackedT initValue{}; + return initValue.mPacked; +} + +// A helper class to get the correct base pointer for a given layout +struct LamportBufferLayout { + uint32_t numStages = 1; + uint32_t bytesPerBuffer = 0; + static constexpr uint32_t sNumLamportBuffers = 3; + + // Implicitly inlined + [[nodiscard]] __device__ __host__ size_t getTotalBytes() const { + return numStages * static_cast(bytesPerBuffer / numStages) * sNumLamportBuffers; + } + + // Implicitly inlined + [[nodiscard]] __device__ __host__ void* getStagePtr(void* bufferBasePtr, uint32_t lamportIndex, + uint32_t stageIndex) const { + // Typecast to avoid warnings + return reinterpret_cast( + reinterpret_cast(bufferBasePtr) + + static_cast((lamportIndex * numStages + stageIndex) * + static_cast(bytesPerBuffer / numStages))); } }; +// Current Index +// Dirty Index +// bytes_per_buffer +// Dirty num_stages +// Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} # We fix this to 4 stages +// offset_access_ptr -template -__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, - int num_tokens, int buffer_M, int token_dim, int rank, - uint32_t* buffer_flags, bool wait_for_results) { - int elt = blockIdx.y * blockDim.x + threadIdx.x; +namespace cg = cooperative_groups; - if (elt >= token_dim) return; - int token = blockIdx.x; +// PackedType is the one used in kernel for Lamport buffer (LDG.128 or LDG.64) +template +__device__ struct __attribute__((aligned(32))) LamportFlags { + public: + __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1) + : mBufferFlagsPtr(bufferFlags), mFlagAccessPtr(&bufferFlags[8]) { + mCurBufferLayout.numStages = numStages; + uint4 flag = reinterpret_cast(bufferFlags)[0]; + mCurrentIndex = flag.x; + mDirtyIndex = flag.y; + // Buffer size is unchanged as the flag should be coupled to each buffer + mCurBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.numStages = flag.w; + *reinterpret_cast(&mBytesToClear) = reinterpret_cast(bufferFlags)[1]; + } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - cudaGridDependencySynchronize(); -#endif + // Return the base pointer of the lamport buffer indexed by mCurrentIndex and the stageIdx + [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const { + return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx); + } - LamportFlags flags(buffer_flags); - - // Capture the number of tokens in previous iteration so that we can properly clear the buffer - // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up - uint32_t clr_toks_cta = - divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, - WORLD_SIZE) * - WORLD_SIZE; - clr_toks_cta = divUp(clr_toks_cta, gridDim.x); - - if (elt < token_dim) { - // Scatter token - int dest_rank = token % WORLD_SIZE; - int dest_token_offset = token / WORLD_SIZE; - T val = shard_ptr[token * token_dim + elt]; - if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + - rank * token_dim + elt] = val; - - // Clear the buffer used by the previous call. Note the number of tokens to clear could be - // larger than the - // number of tokens in the current call. - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat(-0.f); + // Fill the dirty lamport buffer with the init value; Use stageIdx to select the stage to clear, + // -1 to clear all + // FIXME: Current kernel may use less stages than the dirty numStages; How to guarantee the + // correctness? CAUTION: This function requires all threads in the grid to participate and ASSUME + // 1D thread block layout! + __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1) { + // Rasterize the threads to 1D for flexible clearing + + uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y; + uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x; + uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x; + + if (stageIdx == -1) { + // Clear all stages + for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i); } + } else if (stageIdx < mDirtyBufferLayout.numStages) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, + stageIdx); } + } - // Reduce and broadcast - if ((token % WORLD_SIZE) == rank) { - int local_token = token / WORLD_SIZE; - float accum = 0.f; - - T values[WORLD_SIZE]; - - while (1) { - bool valid = true; - for (int r = 0; r < WORLD_SIZE; r++) { - T volatile* lamport_ptr = - (T volatile*)&input_ptrs[rank] - [flags.input_offset + local_token * token_dim * WORLD_SIZE + - r * token_dim + elt]; - values[r] = *lamport_ptr; - valid &= !isNegZero(values[r]); - } - if (valid) break; - } - for (int r = 0; r < WORLD_SIZE; r++) { - accum += toFloat(values[r]); - } - mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = - fromFloat(accum); + __device__ void ctaArrive() { + int tid{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + cg::cluster_group cluster = cg::this_cluster(); + // We update the atomic counter per cluster + tid = cluster.thread_rank(); + cluster.sync(); +#else + tid = threadIdx.x; + __syncthreads(); +#endif + if (tid == 0) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#else + atomicAdd(mFlagAccessPtr, 1); +#endif } } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - cudaTriggerProgrammaticLaunchCompletion(); + __device__ void waitAndUpdate(uint4 bytesToClearPerStage) { + bool isLastCtaT0{false}; + int targetCount{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cg::grid_group grid = cg::this_grid(); + // Use the first thread instead of the last thread as the last thread may exit early + isLastCtaT0 = grid.thread_rank() == 0; + targetCount = grid.num_clusters(); +#else + isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0; + targetCount = gridDim.x * gridDim.y; #endif + if (isLastCtaT0) { + uint4* flagPtr = reinterpret_cast(mBufferFlagsPtr); + while (*reinterpret_cast(mFlagAccessPtr) < targetCount) { + } + // 'Current' becomes 'Dirty' + flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index + mCurrentIndex, // Dirty index + mCurBufferLayout.bytesPerBuffer, // Buffer size + mCurBufferLayout.numStages}; // Dirty - Number of stages + flagPtr[1] = bytesToClearPerStage; + *mFlagAccessPtr = 0; + } + } - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + - elt] = fromFloat(-0.f); + private: + uint32_t* mBufferFlagsPtr; + uint32_t* mFlagAccessPtr; + + uint32_t mCurrentIndex, mDirtyIndex; + // So that we can access it with uint4 + alignas(16) std::array mBytesToClear; + LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout; + + inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, + uint32_t numThreads, uint32_t bytesToClear, + uint8_t dirtyIndex, uint8_t stageIdx) { + // Round up to the float4 boundary + uint32_t clearBoundary = ceil_div(bytesToClear, sizeof(PackedType)); + for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads) { + reinterpret_cast( + mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx] = + getPackedLamportInit(); } } +}; - // Optionally wait for results if the next layer isn't doing the Lamport check - if (wait_for_results) { - // Update the atomic counter to indicate the block has read the offsets - flags.cta_arrive(); - // Only use a set of CTAs for lamport sync, reargange the grid - constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T); - // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) - if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = - blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = - (void*)&input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be - // aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*)&val)) { - val = loadfloat2(lamport_ptr); - } - if (output_ptr) { - *((float2*)&output_ptr[current_pos]) = val; - } +template +union PackedVec { + PackedType packed; + T elements[sizeof(PackedType) / sizeof(T)]; + + __device__ PackedVec& operator+=(PackedVec& other) { +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + elements[i] += other.elements[i]; } + return *this; + } - // Update the buffer flags - flags.wait_and_update(num_tokens); + __device__ PackedVec operator+(PackedVec& other) { + PackedVec result; +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + result.elements[i] = elements[i] + other.elements[i]; + } + return result; } +}; + +template +inline __device__ PackedType loadPacked(T* ptr) { + return *reinterpret_cast(ptr); } -// Template-based dispatch functions following the same pattern as trtllm_allreduce.cuh -template -cudaError_t twoshot_allreduce_dispatch(AllReduceParams& params) { - int const num_threads = 128; - int const num_blocks = (params.token_dim + num_threads - 1) / num_threads; - - dim3 grid(params.num_tokens, num_blocks); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.dynamicSmemBytes = 0; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = num_threads; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; +template +inline __device__ const PackedType loadPacked(T const* ptr) { + return *reinterpret_cast(ptr); +} - cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.buffer_ptrs_dev), - reinterpret_cast(params.multicast_ptr), params.num_tokens, params.buffer_M, - params.token_dim, params.rank, - reinterpret_cast(params.buffer_flags), params.wait_for_results); +template +inline __device__ PackedType loadPackedVolatile(void const* ptr) { + static_assert(sizeof(PackedType) == 0, "Not implemented"); + return PackedType{}; +} - return cudaSuccess; +template <> +inline __device__ float4 loadPackedVolatile(void const* ptr) { + float4 returnValue; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(returnValue.x), "=f"(returnValue.y), "=f"(returnValue.z), "=f"(returnValue.w) + : "l"(ptr)); + return returnValue; } -template -cudaError_t twoshot_allreduce_dispatch_world_size(AllReduceParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_allreduce_dispatch_world_size"); - switch (params.nranks) { - case 2: - return twoshot_allreduce_dispatch(params); - case 4: - return twoshot_allreduce_dispatch(params); - case 8: - return twoshot_allreduce_dispatch(params); - case 16: - return twoshot_allreduce_dispatch(params); - case 32: - return twoshot_allreduce_dispatch(params); - case 64: - return twoshot_allreduce_dispatch(params); - default: - FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + - ". Supported sizes: {2, 4, 8, 16, 32, 64}"); - return cudaErrorInvalidValue; - } +template <> +inline __device__ float2 loadPackedVolatile(void const* ptr) { + float2 returnValue; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" + : "=f"(returnValue.x), "=f"(returnValue.y) + : "l"(ptr)); + return returnValue; } template -__device__ void copy_f4(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4 const*)src; +inline __device__ void copyF4(T_IN* dst, T_IN const* src) { + float4* dst4 = reinterpret_cast(dst); + float4 const* src4 = reinterpret_cast(src); __pipeline_memcpy_async(dst4, src4, sizeof(float4)); } -template -__device__ void copy_f4_ldg(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4*)src; - *dst4 = *src4; -} +uint32_t constexpr kWARP_SIZE = 32U; +uint32_t constexpr kLOG2_WARP_SIZE = 5U; +uint32_t constexpr kLANE_ID_MASK = 0x1f; +uint32_t constexpr kFINAL_MASK = 0xffffffff; -__device__ float4 loadfloat4(void const* ptr) { - // Check alignment - ptr should be 16-byte aligned for safe float4 load - if (reinterpret_cast(ptr) % 16 != 0) { - // Fall back to scalar loads if not aligned - float4 return_value; - float const* float_ptr = reinterpret_cast(ptr); - return_value.x = float_ptr[0]; - return_value.y = float_ptr[1]; - return_value.z = float_ptr[2]; - return_value.w = float_ptr[3]; - return return_value; +template +inline __device__ T warpReduceSumFull(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(kFINAL_MASK, val, mask, kWARP_SIZE); } - - float4 return_value; - - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), - "=f"(return_value.w) - : "l"(ptr)); - - return return_value; + return val; } -// Safer version that checks bounds before loading template -__device__ float4 loadfloat4_safe(T const* ptr, int remaining_elements) { - float return_value[4] = {0.0f, 0.0f, 0.0f, 0.0f}; +inline __device__ T warpReduceSumPartial(T val) { + int laneId = threadIdx.x & kLANE_ID_MASK; + // We make sure only the last warp will call this function + int warpSize = blockDim.x - (threadIdx.x & ~(kWARP_SIZE - 1)); + unsigned int active_mask = (1U << warpSize) - 1; - if (remaining_elements <= 0) { - return *(float4*)return_value; +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + int targetLane = laneId ^ mask; + auto tmp = __shfl_xor_sync(active_mask, val, mask, kWARP_SIZE); + val += targetLane < warpSize ? tmp : 0; } + return val; +} - // Check alignment - ptr should be 16-byte aligned for safe float4 load - bool is_aligned = (reinterpret_cast(ptr) % 16 == 0); - - if (is_aligned && remaining_elements >= 4) { - // Safe to do vectorized load - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), - "=f"(return_value[3]) - : "l"(ptr)); - } else { - // Fall back to scalar loads with bounds checking - float const* float_ptr = reinterpret_cast(ptr); - for (int i = 0; i < 4 && i < remaining_elements; i++) { - return_value[i] = toFloat(float_ptr[i]); - } +// SYNC: +// - True: share the sum across all threads +// - False: only thread 0 get the sum; Other thread's value is undefined. +template +inline __device__ T blockReduceSumPartial(T val) { + __shared__ T smem[kWARP_SIZE]; + int laneId = threadIdx.x & kLANE_ID_MASK; + int warpId = threadIdx.x >> kLOG2_WARP_SIZE; + int warpNum = (blockDim.x + kWARP_SIZE - 1) >> + kLOG2_WARP_SIZE; // Ceiling division to include partial warps + + val = (warpId == warpNum - 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + if (laneId == 0) { + smem[warpId] = val; } + __syncthreads(); - return *(float4*)return_value; -} + if (warpId == 0) { + val = (laneId < warpNum) ? smem[laneId] : (T)0.f; + // Need to consider the corner case where we only have one warp and it is partial + val = (warpNum == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); -template -inline __device__ T add(T a, T b) { - return a + b; + if constexpr (SYNC) { + if (laneId == 0) { + smem[warpId] = val; + } + } + } + if constexpr (SYNC) { + __syncthreads(); + val = smem[0]; + } + return val; } -#define FINAL_MASK 0xffffffff - template -__inline__ __device__ T warpReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, - 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} +inline __device__ T blockReduceSumFull(T val) { + __shared__ T smem[kWARP_SIZE]; + int lane_id = threadIdx.x & kLANE_ID_MASK; + int warp_id = threadIdx.x >> kLOG2_WARP_SIZE; + int warp_num = blockDim.x >> kLOG2_WARP_SIZE; -inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; - val = warpReduceSum(val); + val = warpReduceSumFull(val); if (lane_id == 0) { smem[warp_id] = val; } __syncthreads(); - val = lane_id < warp_num ? smem[lane_id] : 0.f; - val = warpReduceSum(val); + + val = (lane_id < warp_num) ? smem[lane_id] : (T)0.f; + val = warpReduceSumFull(val); + return val; } -template -__global__ void __launch_bounds__(128, 1) - RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, - T_IN const* gamma, float epsilon, T_IN const* residual, int batch_size, - uint32_t* buffer_flags) { +template +inline __device__ T blockReduceSum(T val) { + bool hasPartialWarp = (blockDim.x & kLANE_ID_MASK) != 0; + if (hasPartialWarp) { + return blockReduceSumPartial(val); + } else { + return blockReduceSumFull(val); + } +} +// A helper function to tune the grid configuration for fused oneshot and rmsnorm kernels +// Return (block_size, cluster_size, loads_per_thread) +std::tuple adjustGridConfig(int numTokens, int dim, int eltsPerThread) { + // Start with preferred block_size and cluster_size #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int clusterSize = 8; +#else + int clusterSize = 1; +#endif + int blockSize = 128; + // ========================== Adjust the grid configuration ========================== + int threadsNeeded = ceil_div(dim, eltsPerThread); + int loadsPerThread = 1; - static bool const LAMPORT = true; + blockSize = ceil_div(threadsNeeded, clusterSize); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + while (threadsNeeded % clusterSize != 0 && clusterSize > 1) { + clusterSize /= 2; + } + blockSize = ceil_div(threadsNeeded, clusterSize); + while (blockSize < 128 && clusterSize >= 2) { + blockSize *= 2; + clusterSize /= 2; + } + int smCount = GetCudaMultiProcessorCount(); + while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { + blockSize *= 2; + clusterSize /= 2; + } +#endif - extern __shared__ uint8_t smem[]; + // Trying to scale up use multiple loads or CGA + while (blockSize > 1024) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (clusterSize < 8) { + clusterSize = clusterSize << 1; + } else { + break; + } +#else + if (loadsPerThread < 8) { + loadsPerThread += 1; + } else { + break; + } +#endif + blockSize = ceil_div(threadsNeeded, clusterSize * loadsPerThread); + } + return {blockSize, clusterSize, loadsPerThread}; +} +}; // namespace utils + +using utils::blockReduceSum; +using utils::fromFloat; +using utils::isNegZero; +using utils::LamportFlags; +using utils::loadPacked; +using utils::loadPackedVolatile; +using utils::PackedVec; +using utils::toFloat; + +template +__global__ void __launch_bounds__(1024) + oneshotAllreduceFusionKernel(T* outputPtr, T* prenormedPtr, T const* shardPtr, + T const* residualInPtr, T const* gammaPtr, T** inputPtrs, + T* mcastPtr, int const numTokens, int const tokenDim, + float epsilon, int const rank, uint32_t* bufferFlags) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int packedIdx = cluster.thread_rank(); + int token = blockIdx.x; + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; - int sample = blockIdx.y; + cudaGridDependencySynchronize(); +#else + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; +#endif - static int const CGA_THREADS = NUM_THREADS * 1; + // We only use 1 stage for the oneshot allreduce + LamportFlags flag(bufferFlags, 1); + T* stagePtrMcast = reinterpret_cast(flag.getCurLamportBuf(mcastPtr, 0)); + T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); - static int const ITERS = DIM / CGA_THREADS; - float r_input[ITERS]; - float r_gamma[ITERS]; + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.ctaArrive(); + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } - T_IN* sh_input = (T_IN*)&smem[0]; - T_IN* sh_residual = (T_IN*)&smem[NUM_INPUTS * NUM_THREADS * ITERS * sizeof(T_IN)]; - T_IN* sh_gamma = (T_IN*)&smem[(NUM_INPUTS + 1) * NUM_THREADS * ITERS * sizeof(T_IN)]; + // ==================== Broadcast tokens to each rank ============================= + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) val.elements[i] = fromFloat(0.f); + } - static int const ELTS_PER_THREAD = sizeof(float4) / sizeof(T_IN); + reinterpret_cast( + &stagePtrMcast[token * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = val.packed; - int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; + flag.ctaArrive(); + // ======================= Lamport Sync and clear the output buffer from previous iteration + // ============================= + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); - LamportFlags flags(buffer_flags); - T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); + } + } + if (valid) { + break; + } + } + + auto values = reinterpret_cast*>(valuesLamport); + // ======================= Reduction ============================= + float accum[kELTS_PER_THREAD]; + PackedVec packedAccum; + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] = toFloat(values[0].elements[i]); + } + +#pragma unroll + for (int r = 1; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif + if constexpr (RMSNormFusion) { + // =============================== Residual =============================== + PackedVec residualIn; + residualIn.packed = *reinterpret_cast(&residualInPtr[threadOffset]); + packedAccum += residualIn; + *reinterpret_cast(&prenormedPtr[threadOffset]) = packedAccum.packed; + // =============================== Rmsnorm ================================ + PackedVec gamma; + gamma.packed = *reinterpret_cast(&gammaPtr[packedIdx * kELTS_PER_THREAD]); + + float threadSum = 0.F; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + // FIXME: Use float square if accuracy issue + threadSum += toFloat(packedAccum.elements[i] * packedAccum.elements[i]); + } + float blockSum = blockReduceSum(threadSum); - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - offsets[i][j] = - i * batch_size * DIM + sample * DIM + blockIdx.x * DIM / 1 + k * ELTS_PER_THREAD; + __shared__ float sharedVal[8]; // Temporary variable to share the sum within block + float fullSum = blockSum; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; + } + } +#endif + float rcpRms = rsqrtf(fullSum / tokenDim + epsilon); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * + toFloat(gamma.elements[i])); } } + reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; + flag.waitAndUpdate( + {static_cast(numTokens * tokenDim * WorldSize * kELT_SIZE), 0, 0, 0}); +} -#pragma unroll - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_residual[i * ELTS_PER_THREAD], - &residual[sample * DIM + blockIdx.x * DIM + i * ELTS_PER_THREAD]); +using utils::adjustGridConfig; + +template +cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const eltsPerThread = sizeof(float4) / sizeof(T); + + auto [blockSize, clusterSize, loadsPerThread] = + adjustGridConfig(numTokens, tokenDim, eltsPerThread); + dim3 grid(numTokens, clusterSize, 1); + + FLASHINFER_CHECK(blockSize <= 1024 && loadsPerThread == 1, + "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", + tokenDim, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 1024 * 8 * eltsPerThread); +#else + 1024 * eltsPerThread); +#endif + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: " + "%d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, + ceil_div(tokenDim, eltsPerThread)); + + cudaLaunchAttribute attrs[2]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + attrs[1].id = cudaLaunchAttributeClusterDimension; + attrs[1].val.clusterDim.x = 1; + attrs[1].val.clusterDim.y = clusterSize; + attrs[1].val.clusterDim.z = 1; +#endif + + cudaLaunchConfig_t config{ + .gridDim = grid, + .blockDim = blockSize, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = attrs, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + .numAttrs = 2, +#else + .numAttrs = 1, +#endif + }; + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, RMSNORM) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &config, &oneshotAllreduceFusionKernel, output, residualOut, input, \ + residualIn, gamma, ucPtrs, mcPtr, numTokens, tokenDim, static_cast(params.epsilon), \ + params.rank, params.bufferFlags)); +#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + if (params.rmsNormFusion) { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true); \ + } else { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false); \ } - __pipeline_commit(); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T* residualOut = reinterpret_cast(params.residualOut); + T const* input = reinterpret_cast(params.input); + T const* residualIn = reinterpret_cast(params.residualIn); + T const* gamma = reinterpret_cast(params.gamma); -#pragma unroll - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_gamma[i * ELTS_PER_THREAD], &gamma[blockIdx.x * DIM + i * ELTS_PER_THREAD]); + switch (params.nRanks) { + // FIXME: Do we need other world sizes? + case 2: + DISPATCH_ALLREDUCE_KERNEL(2); + break; + case 4: + DISPATCH_ALLREDUCE_KERNEL(4); + break; + case 8: + DISPATCH_ALLREDUCE_KERNEL(8); + break; + case 16: + DISPATCH_ALLREDUCE_KERNEL(16); + break; + case 32: + DISPATCH_ALLREDUCE_KERNEL(32); + break; + case 64: + DISPATCH_ALLREDUCE_KERNEL(64); + break; + default: + FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nRanks) + + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); + return cudaErrorInvalidValue; } +#undef LAUNCH_ALLREDUCE_KERNEL + return cudaSuccess; +} - __pipeline_commit(); - flags.cta_arrive(); +enum MNNVLTwoShotStage : uint8_t { + SCATTER = 0, + BROADCAST = 1, + NUM_STAGES = 2, +}; - // Load all inputs - bool valid = false; +template +__global__ __launch_bounds__(128) void twoshotAllreduceKernel( + T* outputPtr, T const* shardPtr, T** inputPtrs, T* mcastPtr, uint32_t const numTokens, + uint32_t const tokenDim, uint32_t const rank, uint32_t* bufferFlags, + bool const wait_for_results) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); + + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if (!LAMPORT) cudaGridDependencySynchronize(); + int destRank = token % WorldSize; + int destTokenOffset = token / WorldSize; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); #endif + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); - while (!valid) { - valid = true; -#pragma unroll - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - - float4* dst4 = (float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - - // Calculate the absolute element offset from the start of buffer_input - int element_offset = offsets[i][j]; - - // The input pointer is already offset to: &buffer_input[buffer_offset + buffer_size] - // So the actual pointer we're accessing is: input + element_offset - // Which equals: &buffer_input[buffer_offset + buffer_size + element_offset] - - float4* src4 = (float4*)&input[element_offset]; - - float4 value; - // Check if we have enough elements remaining for a safe float4 load - if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= flags.buffer_size) { - value = loadfloat4(src4); - } else { - // Use safe load for boundary cases or out-of-bounds - int remaining_elements = flags.buffer_size - element_offset; - if (remaining_elements <= 0) { - // Completely out of bounds, return zeros - float4 return_value = {0.0f, 0.0f, 0.0f, 0.0f}; - value = return_value; - } else { - value = loadfloat4_safe(reinterpret_cast(src4), remaining_elements); - } - } + T* scatterBufLocal = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER)); + T* scatterBufDest = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[destRank], MNNVLTwoShotStage::SCATTER)); + T* broadcastBufW = + reinterpret_cast(flag.getCurLamportBuf(mcastPtr, MNNVLTwoShotStage::BROADCAST)); + T* broadcastBufR = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST)); - if (LAMPORT) { - // Assume that the 16B were written atomically, so we only need to check one value - T_IN lowest_val = *(T_IN*)&value; - valid &= !isNegZero(lowest_val); - } - *dst4 = value; - } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Make sure the clear function is called before OOB thread exits + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } + + // =============================== Scatter =============================== + + // Load vectorized data + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) { + val.elements[i] = fromFloat(0.F); } } - __syncthreads(); + // Store vectorized data + reinterpret_cast( + &scatterBufDest[destTokenOffset * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = + val.packed; - // Perform the initial input reduction - if (NUM_INPUTS > 0) { - T_IN accum[ELTS_PER_THREAD]; - float4* accum4 = (float4*)&accum; + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER); - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; + // =============================== Reduction and Broadcast =============================== - *accum4 = *(float4*)&sh_input[k * ELTS_PER_THREAD]; + if ((token % WorldSize) == rank) { + int localToken = token / WorldSize; + float accum[kELTS_PER_THREAD] = {0.F}; + + // Use float as we only check each float value for validity + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &scatterBufLocal[localToken * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); - for (int i = 1; i < NUM_INPUTS; i++) { - float4 data = *(float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - T_IN* p_d = (T_IN*)&data; - for (int x = 0; x < ELTS_PER_THREAD; x++) { - accum[x] += p_d[x]; + // Check validity across all elements +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); } } + if (valid) { + break; + } + } + + // Now we view it as the value for reduction + auto values = reinterpret_cast*>(valuesLamport); +#pragma unroll + for (int r = 0; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } - // Write back to input 0's staging location. No sync needed since all data localized to - // thread. - *(float4*)&sh_input[k * ELTS_PER_THREAD] = *accum4; + // Store vectorized result + PackedVec packedAccum; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); } + reinterpret_cast(&broadcastBufW[token * tokenDim])[packedIdx] = packedAccum.packed; } - // Wait for residual - __pipeline_wait_prior(1); - __syncthreads(); + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST); - float thread_sum = 0.f; + // Optionally wait for results if the next layer isn't doing the Lamport check + if (wait_for_results) { + // Update the atomic counter to indicate the block has read the offsets + flag.ctaArrive(); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 inp4 = - *(float4*)&sh_input[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - float4 res4 = - *(float4*)&sh_residual[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; + PackedVec valLamport; + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + while (isNegZero(valLamport.elements[0])) { + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + } + if (outputPtr) { + reinterpret_cast(&outputPtr[threadOffset])[0] = valLamport.packed; + } - T_IN* r_inp = (T_IN*)&inp4; - T_IN* r_res = (T_IN*)&res4; + // Update the buffer flags + flag.waitAndUpdate( + {static_cast(round_up(numTokens, WorldSize) * tokenDim * + kELT_SIZE), // Clear Size for scatter stage + static_cast(numTokens * tokenDim * kELT_SIZE), // Clear Size for broadcast stage + 0, 0}); + // If not wait for results, we will rely on the following kernel to update the buffer + } +} - float4 out4; +using utils::copyF4; +// This kernel works performant when loads_per_thread is 1. +// For this mode, we are able to support up to 1024 (threads) x 8 (elements) = 8192 hidden +// dimension. There are two options for further scaling up: +// 1. Use CGA if supported. It expands the hidden dimension to 8k x 8 = 64k. +// 2. Set loads_per_thread >1. Which can be used if CGA is not supported. Note that this will +// be limited by the shared memory size and register count. +template +__global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OUT* outputNorm, + T_IN* bufferInput, T_IN const* gamma, + float epsilon, T_IN const* residual, + uint32_t numTokens, uint32_t dim, + uint32_t worldSize, uint32_t* bufferFlags) { + static_assert(std::is_same_v, "T_IN and T_OUT must be the same type"); + static int const kELTS_PER_LOAD = sizeof(float4) / sizeof(T_IN); + + uint32_t const token = blockIdx.x; + uint32_t const blockSize = blockDim.x; + uint32_t const threadOffset = threadIdx.x; + + uint32_t numThreads = blockSize; + uint32_t clusterSize = 1; + uint32_t blockOffset = 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + numThreads = cluster.num_threads(); + clusterSize = cluster.num_blocks(); + blockOffset = cluster.block_rank(); +#endif + uint32_t const dimPadded = round_up(dim, kELTS_PER_LOAD * numThreads); + uint32_t const elemsPerThread = dimPadded / numThreads; + uint32_t const loadStride = blockSize; - T_IN* r_out = (T_IN*)&out4; + extern __shared__ uint8_t smem[]; + float rInput[LoadsPerThread * kELTS_PER_LOAD]; + uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD]; - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; + uint32_t const smemBufferSize = blockSize * elemsPerThread * sizeof(T_IN); + T_IN* smemInput = (T_IN*)&smem[0]; + T_IN* smemResidual = (T_IN*)&smem[smemBufferSize]; + T_IN* smemGamma = (T_IN*)&smem[2 * smemBufferSize]; - T_IN inp_plus_resid = r_inp[ii] + r_res[ii]; - r_out[ii] = inp_plus_resid; - r_input[i] = toFloat(inp_plus_resid); + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); + T_IN* input = reinterpret_cast( + flag.getCurLamportBuf(reinterpret_cast(bufferInput), MNNVLTwoShotStage::BROADCAST)); - // Accumulate the squares for RMSNorm - thread_sum += toFloat(inp_plus_resid * inp_plus_resid); - } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // The offset that current thread should load from. Note that the hidden dimension is split by CGA + // size and each block loads a contiguous chunk; The size of chunk that each block processes + uint32_t const blockChunkSize = ceil_div(dim, clusterSize * kELTS_PER_LOAD) * kELTS_PER_LOAD; + uint32_t const blockLoadOffset = token * dim + blockOffset * blockChunkSize; - *(float4*)&input_plus_residual[sample * DIM + blockIdx.x * DIM + - io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + // Each block load a contiguous chunk of tokens + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + offsets[i] = blockLoadOffset + threadLoadOffset; } - // Wait for Gamma. There will be a global synchronization as part of the reduction - __pipeline_wait_prior(0); +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemResidual[threadLoadOffset], &residual[blockLoadOffset + threadLoadOffset]); + } + } + __pipeline_commit(); +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemGamma[threadLoadOffset], &gamma[blockOffset * blockChunkSize + threadLoadOffset]); + } + } + __pipeline_commit(); - float cluster_sum = block_reduce_sum(thread_sum); + flag.ctaArrive(); + bool valid = false; + // ACQBLK if not lamport + while (!valid) { + valid = true; +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; - float rcp_rms = rsqrtf(cluster_sum / DIM + epsilon); + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + float4* dst4 = reinterpret_cast(&smemInput[threadLoadOffset]); + float4 const* src4 = reinterpret_cast(&input[offsets[i]]); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 gamma4 = - *(float4*)&sh_gamma[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - T_IN* r_g4 = (T_IN*)&gamma4; - - float4 out4; - // FIXME: this only works if T_OUT == T_IN - T_OUT* r_out = (T_OUT*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - r_gamma[i] = toFloat(r_g4[ii]); - r_out[ii] = fromFloat(r_gamma[i] * r_input[i] * rcp_rms); + float4 value = loadPackedVolatile(src4); + // Assume that the 16B were written atomically, so we only need to check one value + valid &= !isNegZero(value.x); + *dst4 = value; + } } + } + + __pipeline_wait_prior(1); + __syncthreads(); + + float threadSum = 0.f; +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + int threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec inp{.packed = loadPacked(&smemInput[threadLoadOffset])}; + PackedVec res{.packed = loadPacked(&smemResidual[threadLoadOffset])}; - *(float4*)&output_norm[sample * DIM + blockIdx.x * DIM + io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; + PackedVec inp_plus_res = inp + res; +#pragma unroll + for (int j = 0; j < kELTS_PER_LOAD; j++) { + rInput[i * kELTS_PER_LOAD + j] = toFloat(inp_plus_res.elements[j]); + threadSum += toFloat(inp_plus_res.elements[j] * inp_plus_res.elements[j]); + } + + *reinterpret_cast(&outputPreNorm[blockLoadOffset + threadLoadOffset]) = + inp_plus_res.packed; + } } - // Update the buffer pointers - flags.wait_and_update(batch_size); -#endif -} -template -cudaError_t twoshot_rmsnorm_dispatch(RMSNormParams& params) { - static constexpr int NUM_THREADS = 128; - static constexpr int CGA_THREADS = NUM_THREADS; - constexpr int iters = H_DIM / CGA_THREADS; + __pipeline_wait_prior(0); - dim3 grid(1, params.batch, 1); + float blockSum = blockReduceSum(threadSum); - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = NUM_THREADS; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; + float fullSum = blockSum; + __shared__ float sharedVal[8]; + // Use CGA Reduction if supported +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; + } + } +#endif - size_t shmem_size = 3 * NUM_THREADS * iters * sizeof(T); - config.dynamicSmemBytes = shmem_size; + float rcpRms = rsqrtf(fullSum / dim + epsilon); - cudaFuncSetAttribute(&RMSNorm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + PackedVec r_out; + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec gamma = {.packed = loadPacked(&smemGamma[threadLoadOffset])}; - cudaLaunchKernelEx( - &config, &RMSNorm, reinterpret_cast(params.residual_output), - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.gamma), static_cast(params.epsilon), - reinterpret_cast(params.residual), params.batch, params.buffer_flags); +#pragma unroll + for (uint32_t j = 0; j < kELTS_PER_LOAD; j++) { + r_out.elements[j] = fromFloat(toFloat(gamma.elements[j]) * + rInput[i * kELTS_PER_LOAD + j] * rcpRms); + } - return cudaSuccess; + *reinterpret_cast(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed; + } + } + constexpr int kELTS_SIZE = sizeof(T_IN); + + // Update the buffer pointers + flag.waitAndUpdate({static_cast(round_up(numTokens, worldSize) * dim * kELTS_SIZE), + static_cast(numTokens * dim * kELTS_SIZE), 0, 0}); } template -cudaError_t twoshot_rmsnorm_dispatch_hidden_dim(RMSNormParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_rmsnorm_dispatch_hidden_dim"); - switch (params.hidden_dim) { - case 2048: - return twoshot_rmsnorm_dispatch(params); - case 4096: - return twoshot_rmsnorm_dispatch(params); - case 5120: - return twoshot_rmsnorm_dispatch(params); // Llama-4 - case 7168: - return twoshot_rmsnorm_dispatch(params); // DeepSeek - case 8192: - return twoshot_rmsnorm_dispatch(params); +cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const numEltsPerThread = sizeof(float4) / sizeof(T); + FLASHINFER_CHECK(tokenDim % numEltsPerThread == 0, + "[MNNVL AllReduceTwoShot] token_dim must be divisible by %d", numEltsPerThread); + + int const arNumThreads = ceil_div(tokenDim, numEltsPerThread); + int const arNumBlocksPerToken = ceil_div(arNumThreads, 128); + + dim3 arGrid(numTokens, arNumBlocksPerToken); + + cudaLaunchAttribute arAttrs[1]; + arAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + arAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; + + cudaLaunchConfig_t arConfig{ + .gridDim = arGrid, + .blockDim = 128, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = arAttrs, + .numAttrs = 1, + }; + + FLASHINFER_LOG_DEBUG("[MNNVL AllReduceTwoShot] Dispatch: grid size: (%d, %d, 1), block_size: 128", + numTokens, arNumBlocksPerToken); + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &arConfig, &twoshotAllreduceKernel, output, input, ucPtrs, mcastPtr, \ + numTokens, tokenDim, params.rank, params.bufferFlags, (!params.rmsNormFusion))); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcastPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T const* input = reinterpret_cast(params.input); + switch (params.nRanks) { + case 2: + LAUNCH_ALLREDUCE_KERNEL(2); + break; + case 4: + LAUNCH_ALLREDUCE_KERNEL(4); + break; + case 8: + LAUNCH_ALLREDUCE_KERNEL(8); + break; + case 16: + LAUNCH_ALLREDUCE_KERNEL(16); + break; + case 32: + LAUNCH_ALLREDUCE_KERNEL(32); + break; + case 64: + LAUNCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL TwoShot RMSNorm: unsupported hidden_dim " + - std::to_string(params.hidden_dim) + - ". Supported sizes: {2048, 4096, 5120, 7168, 8192}"); + FLASHINFER_ERROR("[MNNVL AllReduceTwoShot] Unsupported world_size" + + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } -} +#undef LAUNCH_ALLREDUCE_KERNEL + + // Launch the rmsnorm lamport kernel if fusion is enabled + if (params.rmsNormFusion) { + auto gridConfig = adjustGridConfig(numTokens, tokenDim, numEltsPerThread); + int rnBlockSize = std::get<0>(gridConfig); + int rnClusterSize = std::get<1>(gridConfig); + int rnLoadsPerThread = std::get<2>(gridConfig); + + int rnNumThreads = rnClusterSize * rnBlockSize; + dim3 rnGrid(numTokens, rnClusterSize, 1); + cudaLaunchConfig_t rnConfig; + cudaLaunchAttribute rnAttrs[2]; + rnConfig.stream = params.stream; + rnConfig.gridDim = rnGrid; + rnConfig.blockDim = rnBlockSize; + rnConfig.attrs = rnAttrs; + rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + rnAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#ifndef DISABLE_CGA + rnAttrs[1].id = cudaLaunchAttributeClusterDimension; + rnAttrs[1].val.clusterDim.x = 1; + rnAttrs[1].val.clusterDim.y = rnClusterSize; + rnAttrs[1].val.clusterDim.z = 1; + rnConfig.numAttrs = 2; +#else + rnConfig.numAttrs = 1; +#endif + bool const rnUseCGA = rnClusterSize > 1; + int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); + int const iters = dimPadded / rnNumThreads; + + size_t const smemSize = 3 * rnBlockSize * iters * sizeof(T); + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " + "cluster_size: %d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, rnClusterSize, rnBlockSize, rnClusterSize, rnLoadsPerThread, + ceil_div(tokenDim, numEltsPerThread)); + +#define RUN_RMSNORM_KERNEL(LOADS_PER_THREAD) \ + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(&rmsNormLamport, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + smemSize)); \ + rnConfig.dynamicSmemBytes = smemSize; \ + FLASHINFER_CUDA_CALL( \ + cudaLaunchKernelEx(&rnConfig, &rmsNormLamport, residualOut, output, \ + bufferInput, gamma, static_cast(params.epsilon), residualIn, \ + numTokens, tokenDim, params.nRanks, params.bufferFlags)); + + T* residualOut = reinterpret_cast(params.residualOut); + T* output = reinterpret_cast(params.output); + T* bufferInput = reinterpret_cast(params.bufferPtrLocal); + T const* gamma = reinterpret_cast(params.gamma); + T const* residualIn = reinterpret_cast(params.residualIn); + if (rnUseCGA) { + RUN_RMSNORM_KERNEL(1); + } else { + switch (rnLoadsPerThread) { + case 1: + RUN_RMSNORM_KERNEL(1); + break; + case 2: + RUN_RMSNORM_KERNEL(2); + break; + case 3: + RUN_RMSNORM_KERNEL(3); + break; + case 4: + RUN_RMSNORM_KERNEL(4); + break; + case 5: + RUN_RMSNORM_KERNEL(5); + break; + case 6: + RUN_RMSNORM_KERNEL(6); + break; + case 7: + RUN_RMSNORM_KERNEL(7); + break; + case 8: + RUN_RMSNORM_KERNEL(8); + break; + default: + FLASHINFER_ERROR("[MNNVL AllReduceTwoShotRMSNorm] Unsupported loads_per_thread" + + std::to_string(rnLoadsPerThread) + + ". Supported sizes: {1, 2, 3, 4, 5, 6, 7, 8}"); + return cudaErrorInvalidValue; + } // switch (rnLoadsPerThread) + } // if (rnUseCGA) +#undef RUN_RMSNORM_KERNEL + + } // if (params.rmsNormFusion) + return cudaSuccess; +} } // namespace trtllm_mnnvl_allreduce } // namespace flashinfer diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 5b26d7beaf..e7f9d608a5 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -289,6 +290,22 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +// This function is thread-safe and cached the sm_count. +// But it will only check the current CUDA device, thus assuming each process handles single GPU. +inline int GetCudaMultiProcessorCount() { + static std::atomic sm_count{0}; + int cached = sm_count.load(std::memory_order_relaxed); + if (cached == 0) { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_id); + cached = device_prop.multiProcessorCount; + sm_count.store(cached, std::memory_order_relaxed); + } + return cached; +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index abb3795019..cf93b1af6c 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,5 +1,6 @@ # Check torch version: -from typing import Tuple +import traceback +from typing import Tuple, Optional import pytest import torch @@ -14,6 +15,95 @@ @torch.inference_mode() def row_linear_residual_norm_fusion_forward( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + mapping: Mapping, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, +): + tensor_parallel_rank = mapping.tp_rank + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + workspace, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + use_pdl = True + + if enable_fusion: + trtllm_mnnvl_ar.mpi_barrier() + + output, residual_out = ( + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input, + residual, + norm_weight, + workspace, + eps, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + ) + + return output.view(shape), residual_out.view(shape) + + else: + output = torch.empty_like(input) + + output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce( + input, + workspace, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + return (output.view(shape),) + + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + +@torch.inference_mode() +def row_linear_residual_norm_fusion_forward_legacy( x: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, @@ -29,14 +119,8 @@ def row_linear_residual_norm_fusion_forward( max_num_elements_mnnvl: int, buffer_flags_mnnvl: torch.Tensor, ): - x = x.cuda() - residual = residual.cuda() - norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() def func( @@ -52,11 +136,7 @@ def func( ): # For both fused and unfused cases: shape = input.shape - - assert max_num_elements_mnnvl % hidden_size == 0 - input = input.view(-1, shape[-1]) - buffer_M = max_num_elements_mnnvl // hidden_size if enable_fusion: @@ -150,13 +230,55 @@ def func( """Helper function to run the core MNNVL AllReduce test logic""" +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + # Communicator used for passing data between ranks + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] = None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + def run_mnnvl_ar_full( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int, - explicit_workspace_bytes: int | None = None, + legacy_explicit_workspace_bytes: Optional[int] = None, + legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. @@ -168,17 +290,15 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + comm = MPI.COMM_WORLD # Get MPI info - rank = MPI.COMM_WORLD.Get_rank() - world_size = MPI.COMM_WORLD.Get_size() + rank = comm.Get_rank() + world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") - - # Ensure we have exactly 2 ranks for this test if world_size < 2: pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") @@ -199,90 +319,82 @@ def run_mnnvl_ar_full( print( f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" ) - - tensor_parallel_size = world_size eps = 1e-5 - torch.manual_seed(42) + torch.manual_seed(42 + rank) # Track if this rank failed rank_failed = False failure_message = "" try: - # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list - # This workspace is sized for the maximum expected sequence length and can be reused within each list - # Each parameterized list gets its own fresh workspace allocation - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes - ) - ) - - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() - buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() - unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( - mapping.tp_rank - ) - - # Test each sequence length with the same workspace (reusing allocated buffers within this list) - for seq_len in seq_lens: - if rank == 0: - print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + if legacy_api: + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes ) + ) - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), - dtype=dtype, - device=torch.device("cuda"), + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + + else: + required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( + mapping.tp_size, + max(seq_lens), + hidden_size, + dtype, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") + workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( + mapping, required_workspace_bytes ) - # Each rank gets its slice of the input - x = x_full[rank, :, :] + test_data = [] + for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion + ) + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) + ) - # Compute reference output based on fusion mode - reference_output: Tuple[torch.Tensor, ...] = None - if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + for seq_len, x, residual, norm_weight, reference_output in test_data: + if rank == 0: print( - "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device - ) + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) + if legacy_api: + row_linear_residual_norm_fusion_forward_legacy( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + fusion, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, ) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - - reference_output = (norm_out, residual_out) else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - reference_output = (allreduce_result,) - - # Run the test with the same workspace - row_linear_residual_norm_fusion_forward( - x, - residual, - norm_weight, - eps, - hidden_size, - dtype, - mapping, - fusion, - reference_output, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - buffer_flags_mnnvl, - ) + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + mapping, + fusion, + reference_output, + workspace, + ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() @@ -295,6 +407,7 @@ def run_mnnvl_ar_full( rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) + print(traceback.format_exc()) # Gather failure status from all ranks for logging all_failures = MPI.COMM_WORLD.allgather(rank_failed) @@ -305,16 +418,16 @@ def run_mnnvl_ar_full( print(f"Test failed on ranks: {failed_ranks}") # Cleanup before re-raising - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: # Ensure cleanup happens for this list's workspace - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() @@ -325,79 +438,28 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize( "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], ) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_default_workspace( +@pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) +def test_mnnvl_allreduce_refactored( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with default workspace size.""" - run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) - - -"""Test with explicit workspace size""" + """Test MNNVL AllReduce with refactored API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=False + ) -@pytest.mark.parametrize( - "seq_lens", - [ - [1, 4, 180], - ], -) +@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]]) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_explicit_workspace( +def test_mnnvl_allreduce_legacy( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with explicitly calculated workspace size.""" - # Calculate workspace to fit the maximum sequence length - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) + """Test MNNVL AllReduce with legacy API.""" run_mnnvl_ar_full( - monkeypatch, - seq_lens, - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=explicit_workspace_bytes, + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True ) - - -"""Negative test: workspace too small""" - - -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096]) -def test_mnnvl_allreduce_workspace_too_small( - monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" - # Use a large sequence length that won't fit in a small workspace - seq_len = 180 - - # Create a workspace that's too small (only enough for 10 tokens) - small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - - # Expect a ValueError with a message about buffer_M being too small - with pytest.raises((ValueError, RuntimeError)) as exc_info: - run_mnnvl_ar_full( - monkeypatch, - [seq_len], - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=small_workspace_bytes, - ) - - # Verify the error message contains the expected text - assert "greater than the buffer_M" in str(exc_info.value)