From 63b7683e7015cf2afcdbe91795605fe637fa0f99 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Mon, 17 Nov 2025 17:21:31 -0800 Subject: [PATCH 01/14] Staging initial integration of kernel code. --- csrc/trtllm_mnnvl_allreduce.cu | 117 +- .../comm/trtllm_mnnvl_allreduce.cuh | 1483 +++++++++++------ include/flashinfer/utils.cuh | 12 + 3 files changed, 1085 insertions(+), 527 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 6bac5372a8..05a1684aa0 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,77 +26,84 @@ using tvm::ffi::Optional; } \ }() -void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, - int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, - int64_t rank, bool wait_for_results, bool launch_with_pdl, - Optional out) { - cudaSetDevice(in.device().device_id); - auto stream = get_stream(in.device()); +// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this +// level +void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, + int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, + TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, + bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, + TensorView output, Optional residual_out, + Optional gamma, Optional epsilon) { + cudaSetDevice(input.device().device_id); + auto stream = get_stream(input.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] { // Extract parameters from tensors - int64_t num_tokens = in.size(0); - int64_t token_dim = in.size(1); + int64_t num_tokens = input.size(0); + int64_t token_dim = input.size(1); // Validate input parameters - TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0) - << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type); + TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0) + << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type); + TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1)) + << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << output.size(0) << ", " << output.size(1) << ")"; TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64) << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK(out.has_value() || !wait_for_results) - << "out tensor must be provided if wait_for_results is true"; + TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) || + !rmsnorm_fusion) + << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; + + if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_out.size(0) == num_tokens && residual_out.size(1) == token_dim) + << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_out.size(0) << ", " << residual_out.size(1) << ")"; + TVM_FFI_ICHECK(gamma.size(0) == token_dim) + << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" + << gamma.size(0) << ")"; + } // Create the parameters struct - AllReduceParams params; - params.nranks = nranks; - params.rank = rank; - params.buffer_M = buffer_M; - params.num_tokens = num_tokens; - params.token_dim = token_dim; - params.buffer_ptrs_dev = reinterpret_cast(buffer_ptrs_dev); - params.multicast_ptr = reinterpret_cast(multicast_buffer_ptr); - params.buffer_flags = buffer_flags_mnnvl.data_ptr(); - params.wait_for_results = wait_for_results; - params.launch_with_pdl = launch_with_pdl; - params.input = in.data_ptr(); - params.output = out.has_value() ? out.value().data_ptr() : nullptr; - params.stream = stream; + AllReduceFusionParams params; - auto status = twoshot_allreduce_dispatch_world_size(params); - TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); - }); -} + // Aux Information + params.nRanks = nranks; + params.rank = rank; + params.numTokens = num_tokens; + params.tokenDim = token_dim; + params.bufferPtrsDev = reinterpret_cast(buffer_ptrs_dev); + params.bufferPtrLocal = reinterpret_cast(buffer_ptr_local); + params.multicastPtr = reinterpret_cast(multicast_buffer_ptr); + params.bufferFlags = reinterpret_cast(buffer_flags_mnnvl.data_ptr()); + params.rmsNormFusion = rmsnorm_fusion; + params.launchWithPdl = launch_with_pdl; -void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, - TensorView normed_output, TensorView gamma, double epsilon, - TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { - cudaSetDevice(prenorm_output.device().device_id); - auto stream = get_stream(prenorm_output.device()); + // input data + params.input = const_cast(input.data_ptr()); + params.residualIn = residual_out.has_value() + ? const_cast(residual_out.value().data_ptr()) + : nullptr; + params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; + params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] { - // Create the parameters struct - RMSNormParams params; - params.residual_output = prenorm_output.data_ptr(); - params.output = normed_output.data_ptr(); - params.input = reinterpret_cast(multicast_buffer_ptr); - params.gamma = gamma.data_ptr(); - params.epsilon = epsilon; - params.residual = residual.data_ptr(); - params.buffer_flags = reinterpret_cast(buffer_flags.data_ptr()); - params.batch = normed_output.size(0); - params.hidden_dim = normed_output.size(1); + // output data + params.output = const_cast(output.data_ptr()); + params.residualOut = + residual_out.has_value() ? const_cast(residual_out.value().data_ptr()) : nullptr; params.stream = stream; - params.launch_with_pdl = launch_with_pdl; - auto status = twoshot_rmsnorm_dispatch_hidden_dim(params); + + cudaError_t status; + if (use_oneshot) { + status = oneshotAllreduceFusionDispatch(params); + } else { + status = twoshotAllreduceFusionDispatch(params); + } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_rmsnorm_dispatch_hidden_dim failed with error code " + << "twoshot_allreduce_dispatch_world_size failed with error code " << cudaGetErrorString(status); }); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion); diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 3dbed4b649..9198df8775 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -18,52 +18,54 @@ #include #include #include +#include #include +#include #include "../exception.h" #include "../logging.h" +#include "../utils.cuh" namespace flashinfer { namespace trtllm_mnnvl_allreduce { -template -struct AllReduceParams { - int nranks; +struct AllReduceFusionParams { + int nRanks; int rank; - int buffer_M; - int num_tokens; - int token_dim; - void** buffer_ptrs_dev; - void* multicast_ptr; - void* buffer_flags; - bool wait_for_results; - bool launch_with_pdl; - - void* input; - void* output; - cudaStream_t stream; -}; + int numTokens; + int tokenDim; + void** bufferPtrsDev; + void* bufferPtrLocal; + void* multicastPtr; + uint32_t* bufferFlags; + bool rmsNormFusion; + bool launchWithPdl; -template -struct RMSNormParams { - void* residual_output; - void* output; void const* input; + void const* residualIn; void const* gamma; double epsilon; - void* residual; - uint32_t* buffer_flags; - int batch; - int hidden_dim; - cudaStream_t stream; - bool launch_with_pdl; + + void* residualOut; + void* output; + cudaStream_t stream = nullptr; }; -__device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } +namespace utils { + +constexpr uint16_t kNEGZERO_FP16 = 0x8000U; + +template +union Fp16BitCast { + T mFp; + uint16_t mInt; + + constexpr Fp16BitCast() : mInt(0) {} -__device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } + constexpr Fp16BitCast(T val) : mFp(val) {} -__device__ bool isNegZero(__nv_half val) { return isNegZero(__half2float(val)); } + constexpr Fp16BitCast(uint16_t val) : mInt(val) {} +}; template inline __device__ float toFloat(T val) { @@ -74,7 +76,6 @@ template <> inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); } - template <> inline __device__ float toFloat<__nv_half>(__nv_half val) { return __half2float(val); @@ -95,581 +96,1119 @@ inline __device__ __nv_half fromFloat<__nv_half>(float val) { return __float2half(val); } -inline __device__ float2 loadfloat2(void const* ptr) { - float2 return_value; - asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" - : "=f"(return_value.x), "=f"(return_value.y) - : "l"(ptr)); - return return_value; +template +static constexpr __device__ __host__ T negZero() { + if constexpr (std::is_same_v) { + return -0.0F; + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(kNEGZERO_FP16).mFp; + } else { + static_assert(sizeof(T) == 0, "negativeZero not specialized for this type"); + } + return T{}; // Never reached, but needed for compilation } template -inline __device__ T divUp(T val, T divisor) { - return (val + divisor - 1) / divisor; +static inline __device__ bool isNegZero(T val) { + if constexpr (std::is_same_v) { + return val == 0.F && signbit(val); + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(val).mInt == kNEGZERO_FP16; + } else { + static_assert(sizeof(T) == 0, "isNegZero not specialized for this type"); + } + return false; // Never reached, but needed for compilation } +template +constexpr __device__ __host__ PackedType getPackedLamportInit() { + static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size"); + constexpr int kNumElements = sizeof(PackedType) / sizeof(T); + + union PackedT { + PackedType mPacked; + std::array mElements; + + constexpr PackedT() : mElements{} { + for (int i = 0; i < kNumElements; i++) { + mElements[i] = negZero(); + } + } + }; + + PackedT initValue{}; + return initValue.mPacked; +} + +// A helper class to get the correct base pointer for a given layout +struct LamportBufferLayout { + uint32_t numStages = 1; + uint32_t bytesPerBuffer = 0; + static constexpr uint32_t sNumLamportBuffers = 3; + + // Implicitly inlined + [[nodiscard]] __device__ __host__ size_t getTotalBytes() const { + return numStages * static_cast(bytesPerBuffer / numStages) * sNumLamportBuffers; + } + + // Implicitly inlined + [[nodiscard]] __device__ __host__ void* getStagePtr(void* bufferBasePtr, uint32_t lamportIndex, + uint32_t stageIndex) const { + // Typecast to avoid warnings + return reinterpret_cast( + reinterpret_cast(bufferBasePtr) + + static_cast((lamportIndex * numStages + stageIndex) * + static_cast(bytesPerBuffer / numStages))); + } +}; +// Current Index +// Dirty Index +// bytes_per_buffer +// Dirty num_stages +// Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} # We fix this to 4 stages +// offset_access_ptr + +namespace cg = cooperative_groups; + +// PackedType is the one used in kernel for Lamport buffer (LDG.128 or LDG.64) +template __device__ struct __attribute__((aligned(32))) LamportFlags { - uint32_t buffer_size; - uint32_t input_offset; - uint32_t clear_offset; - uint32_t num_tokens_prev; - uint32_t* offset_access_ptr; - uint32_t* buffer_flags; - - __device__ explicit LamportFlags(uint32_t* buffer_flags) - : offset_access_ptr(&buffer_flags[4]), buffer_flags(buffer_flags) { - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; - input_offset = flag.x * (buffer_size << 1U); - clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; - } - - __device__ void cta_arrive() { + public: + __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1) + : mBufferFlagsPtr(bufferFlags), mFlagAccessPtr(&bufferFlags[8]) { + mCurBufferLayout.numStages = numStages; + uint4 flag = reinterpret_cast(bufferFlags)[0]; + mCurrentIndex = flag.x; + mDirtyIndex = flag.y; + // Buffer size is unchanged as the flag should be coupled to each buffer + mCurBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.numStages = flag.w; + *reinterpret_cast(&mBytesToClear) = reinterpret_cast(bufferFlags)[1]; + } + + // Return the base pointer of the lamport buffer indexed by mCurrentIndex and the stageIdx + [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const { + return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx); + } + + // Fill the dirty lamport buffer with the init value; Use stageIdx to select the stage to clear, + // -1 to clear all + // FIXME: Current kernel may use less stages than the dirty numStages; How to guarantee the + // correctness? CAUTION: This function requires all threads in the grid to participate and ASSUME + // 1D thread block layout! + __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1) { + // Rasterize the threads to 1D for flexible clearing + + uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y; + uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x; + uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x; + + if (stageIdx == -1) { + // Clear all stages + for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i); + } + } else if (stageIdx < mDirtyBufferLayout.numStages) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, + stageIdx); + } + } + + __device__ void ctaArrive() { + int tid{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + cg::cluster_group cluster = cg::this_cluster(); + // We update the atomic counter per cluster + tid = cluster.thread_rank(); + cluster.sync(); +#else + tid = threadIdx.x; __syncthreads(); - if (threadIdx.x == 0) { +#endif + if (tid == 0) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); #else - atomicAdd(offset_access_ptr, 1); + atomicAdd(mFlagAccessPtr, 1); #endif } } - __device__ void wait_and_update(uint32_t num_tokens) { - if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) { - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + __device__ void waitAndUpdate(uint4 bytesToClearPerStage) { + bool isLastCtaT0{false}; + int targetCount{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cg::grid_group grid = cg::this_grid(); + // Use the first thread instead of the last thread as the last thread may exit early + isLastCtaT0 = grid.thread_rank() == 0; + targetCount = grid.num_clusters(); +#else + isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0; + targetCount = gridDim.x * gridDim.y; +#endif + if (isLastCtaT0) { + uint4* flagPtr = reinterpret_cast(mBufferFlagsPtr); + while (*reinterpret_cast(mFlagAccessPtr) < targetCount) { } - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; - *(offset_access_ptr) = 0; + // 'Current' becomes 'Dirty' + flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index + mCurrentIndex, // Dirty index + mCurBufferLayout.bytesPerBuffer, // Buffer size + mCurBufferLayout.numStages}; // Dirty - Number of stages + flagPtr[1] = bytesToClearPerStage; + *mFlagAccessPtr = 0; + } + } + + private: + uint32_t* mBufferFlagsPtr; + uint32_t* mFlagAccessPtr; + + uint32_t mCurrentIndex, mDirtyIndex; + // So that we can access it with uint4 + alignas(16) std::array mBytesToClear; + LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout; + + inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, + uint32_t numThreads, uint32_t bytesToClear, + uint8_t dirtyIndex, uint8_t stageIdx) { + // Round up to the float4 boundary + uint32_t clearBoundary = ceil_div(bytesToClear, sizeof(PackedType)); + for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads) { + reinterpret_cast( + mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx] = + getPackedLamportInit(); + } + } +}; + +template +union PackedVec { + PackedType packed; + T elements[sizeof(PackedType) / sizeof(T)]; + + __device__ PackedVec& operator+=(PackedVec& other) { +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + elements[i] += other.elements[i]; + } + return *this; + } + + __device__ PackedVec operator+(PackedVec& other) { + PackedVec result; +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + result.elements[i] = elements[i] + other.elements[i]; } + return result; } }; -template -__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, - int num_tokens, int buffer_M, int token_dim, int rank, - uint32_t* buffer_flags, bool wait_for_results) { - int elt = blockIdx.y * blockDim.x + threadIdx.x; +template +inline __device__ PackedType loadPacked(T* ptr) { + return *reinterpret_cast(ptr); +} + +template +inline __device__ const PackedType loadPacked(T const* ptr) { + return *reinterpret_cast(ptr); +} + +template +inline __device__ PackedType loadPackedVolatile(void const* ptr) { + static_assert(sizeof(PackedType) == 0, "Not implemented"); + return PackedType{}; +} + +template <> +inline __device__ float4 loadPackedVolatile(void const* ptr) { + float4 returnValue; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(returnValue.x), "=f"(returnValue.y), "=f"(returnValue.z), "=f"(returnValue.w) + : "l"(ptr)); + return returnValue; +} + +template <> +inline __device__ float2 loadPackedVolatile(void const* ptr) { + float2 returnValue; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" + : "=f"(returnValue.x), "=f"(returnValue.y) + : "l"(ptr)); + return returnValue; +} - if (elt >= token_dim) return; +template +inline __device__ void copyF4(T_IN* dst, T_IN const* src) { + float4* dst4 = reinterpret_cast(dst); + float4 const* src4 = reinterpret_cast(src); + __pipeline_memcpy_async(dst4, src4, sizeof(float4)); +} + +uint32_t constexpr kWARP_SIZE = 32U; +uint32_t constexpr kLOG2_WARP_SIZE = 5U; +uint32_t constexpr kLANE_ID_MASK = 0x1f; +uint32_t constexpr kFINAL_MASK = 0xffffffff; + +template +inline __device__ T warpReduceSumFull(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(kFINAL_MASK, val, mask, kWARP_SIZE); + } + return val; +} + +template +inline __device__ T warpReduceSumPartial(T val) { + int laneId = threadIdx.x & kLANE_ID_MASK; + // We make sure only the last warp will call this function + int warpSize = blockDim.x - (threadIdx.x & ~(kWARP_SIZE - 1)); + unsigned int active_mask = (1U << warpSize) - 1; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + int targetLane = laneId ^ mask; + auto tmp = __shfl_xor_sync(active_mask, val, mask, kWARP_SIZE); + val += targetLane < warpSize ? tmp : 0; + } + return val; +} + +// SYNC: +// - True: share the sum across all threads +// - False: only thread 0 get the sum; Other thread's value is undefined. +template +inline __device__ T blockReduceSumPartial(T val) { + __shared__ T smem[kWARP_SIZE]; + int laneId = threadIdx.x & kLANE_ID_MASK; + int warpId = threadIdx.x >> kLOG2_WARP_SIZE; + int warpNum = (blockDim.x + kWARP_SIZE - 1) >> + kLOG2_WARP_SIZE; // Ceiling division to include partial warps + + val = (warpId == warpNum - 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + if (laneId == 0) { + smem[warpId] = val; + } + __syncthreads(); + + if (warpId == 0) { + val = (laneId < warpNum) ? smem[laneId] : (T)0.f; + // Need to consider the corner case where we only have one warp and it is partial + val = (warpNum == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + + if constexpr (SYNC) { + if (laneId == 0) { + smem[warpId] = val; + } + } + } + if constexpr (SYNC) { + __syncthreads(); + val = smem[0]; + } + return val; +} + +template +inline __device__ T blockReduceSumFull(T val) { + __shared__ T smem[kWARP_SIZE]; + int lane_id = threadIdx.x & kLANE_ID_MASK; + int warp_id = threadIdx.x >> kLOG2_WARP_SIZE; + int warp_num = blockDim.x >> kLOG2_WARP_SIZE; + + val = warpReduceSumFull(val); + if (lane_id == 0) { + smem[warp_id] = val; + } + __syncthreads(); + + val = (lane_id < warp_num) ? smem[lane_id] : (T)0.f; + val = warpReduceSumFull(val); + + return val; +} + +template +inline __device__ T blockReduceSum(T val) { + bool hasPartialWarp = (blockDim.x & kLANE_ID_MASK) != 0; + if (hasPartialWarp) { + return blockReduceSumPartial(val); + } else { + return blockReduceSumFull(val); + } +} +// A helper function to tune the grid configuration for fused oneshot and rmsnorm kernels +// Return (block_size, cluster_size, loads_per_thread) +std::tuple adjustGridConfig(int numTokens, int dim, int eltsPerThread) { + // Start with preferred block_size and cluster_size +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int clusterSize = 8; +#else + int clusterSize = 1; +#endif + int blockSize = 128; + // ========================== Adjust the grid configuration ========================== + int threadsNeeded = ceil_div(dim, eltsPerThread); + int loadsPerThread = 1; + + blockSize = ceil_div(threadsNeeded, clusterSize); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + while (threadsNeeded % clusterSize != 0 && clusterSize > 1) { + clusterSize /= 2; + } + blockSize = ceil_div(threadsNeeded, clusterSize); + while (blockSize < 128 && clusterSize >= 2) { + blockSize *= 2; + clusterSize /= 2; + } + int smCount = GetCudaMultiProcessorCount(); + while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { + blockSize *= 2; + clusterSize /= 2; + } +#endif + + // Trying to scale up use multiple loads or CGA + while (blockSize > 1024) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (clusterSize < 8) { + clusterSize = clusterSize << 1; + } else { + break; + } +#else + if (loadsPerThread < 8) { + loadsPerThread += 1; + } else { + break; + } +#endif + blockSize = ceil_div(threadsNeeded, clusterSize * loadsPerThread); + } + return {blockSize, clusterSize, loadsPerThread}; +} +}; // namespace utils + +using utils::blockReduceSum; +using utils::fromFloat; +using utils::isNegZero; +using utils::LamportFlags; +using utils::loadPacked; +using utils::loadPackedVolatile; +using utils::PackedVec; +using utils::toFloat; + +template +__global__ void __launch_bounds__(1024) + oneshotAllreduceFusionKernel(T* outputPtr, T* prenormedPtr, T const* shardPtr, + T const* residualInPtr, T const* gammaPtr, T** inputPtrs, + T* mcastPtr, int const numTokens, int const tokenDim, + float epsilon, int const rank, uint32_t* bufferFlags) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int packedIdx = cluster.thread_rank(); int token = blockIdx.x; + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); +#else + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; #endif - LamportFlags flags(buffer_flags); - - // Capture the number of tokens in previous iteration so that we can properly clear the buffer - // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up - uint32_t clr_toks_cta = - divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, - WORLD_SIZE) * - WORLD_SIZE; - clr_toks_cta = divUp(clr_toks_cta, gridDim.x); - - if (elt < token_dim) { - // Scatter token - int dest_rank = token % WORLD_SIZE; - int dest_token_offset = token / WORLD_SIZE; - T val = shard_ptr[token * token_dim + elt]; - if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + - rank * token_dim + elt] = val; - - // Clear the buffer used by the previous call. Note the number of tokens to clear could be - // larger than the - // number of tokens in the current call. - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat(-0.f); + // We only use 1 stage for the oneshot allreduce + LamportFlags flag(bufferFlags, 1); + T* stagePtrMcast = reinterpret_cast(flag.getCurLamportBuf(mcastPtr, 0)); + T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); + + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } + + // ==================== Broadcast tokens to each rank ============================= + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) val.elements[i] = toFloat(0.f); + } + + reinterpret_cast( + &stagePtrMcast[token * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = val.packed; + + flag.ctaArrive(); + // ======================= Lamport Sync and clear the output buffer from previous iteration + // ============================= + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); + +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); } } + if (valid) { + break; + } + } - // Reduce and broadcast - if ((token % WORLD_SIZE) == rank) { - int local_token = token / WORLD_SIZE; - float accum = 0.f; - - T values[WORLD_SIZE]; - - while (1) { - bool valid = true; - for (int r = 0; r < WORLD_SIZE; r++) { - T volatile* lamport_ptr = - (T volatile*)&input_ptrs[rank] - [flags.input_offset + local_token * token_dim * WORLD_SIZE + - r * token_dim + elt]; - values[r] = *lamport_ptr; - valid &= !isNegZero(values[r]); - } - if (valid) break; - } - for (int r = 0; r < WORLD_SIZE; r++) { - accum += toFloat(values[r]); - } - mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = - fromFloat(accum); + auto values = reinterpret_cast*>(valuesLamport); + // ======================= Reduction ============================= + float accum[kELTS_PER_THREAD]; + PackedVec packedAccum; + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] = toFloat(values[0].elements[i]); + } + +#pragma unroll + for (int r = 1; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); } } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + - elt] = fromFloat(-0.f); + if constexpr (RMSNormFusion) { + // =============================== Residual =============================== + PackedVec residualIn; + residualIn.packed = *reinterpret_cast(&residualInPtr[threadOffset]); + packedAccum += residualIn; + *reinterpret_cast(&prenormedPtr[threadOffset]) = packedAccum.packed; + // =============================== Rmsnorm ================================ + PackedVec gamma; + gamma.packed = *reinterpret_cast(&gammaPtr[packedIdx * kELTS_PER_THREAD]); + + float threadSum = 0.F; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + // FIXME: Use float square if accuracy issue + threadSum += toFloat(packedAccum.elements[i] * packedAccum.elements[i]); } - } + float blockSum = blockReduceSum(threadSum); - // Optionally wait for results if the next layer isn't doing the Lamport check - if (wait_for_results) { - // Update the atomic counter to indicate the block has read the offsets - flags.cta_arrive(); - // Only use a set of CTAs for lamport sync, reargange the grid - constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T); - // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) - if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = - blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = - (void*)&input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be - // aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*)&val)) { - val = loadfloat2(lamport_ptr); + __shared__ float sharedVal[8]; // Temporary variable to share the sum within block + float fullSum = blockSum; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; } - if (output_ptr) { - *((float2*)&output_ptr[current_pos]) = val; + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; } } - - // Update the buffer flags - flags.wait_and_update(num_tokens); +#endif + float rcpRms = rsqrtf(fullSum / tokenDim + epsilon); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * + fromFloat(gamma.elements[i])); + } } + reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; + flag.waitAndUpdate( + {static_cast(numTokens * tokenDim * WorldSize * kELT_SIZE), 0, 0, 0}); } -// Template-based dispatch functions following the same pattern as trtllm_allreduce.cuh -template -cudaError_t twoshot_allreduce_dispatch(AllReduceParams& params) { - int const num_threads = 128; - int const num_blocks = (params.token_dim + num_threads - 1) / num_threads; - - dim3 grid(params.num_tokens, num_blocks); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.dynamicSmemBytes = 0; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = num_threads; - config.attrs = attrs; +using utils::adjustGridConfig; + +template +cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const eltsPerThread = sizeof(float4) / sizeof(T); + + auto [blockSize, clusterSize, loadsPerThread] = + adjustGridConfig(numTokens, tokenDim, eltsPerThread); + dim3 grid(numTokens, clusterSize, 1); + + FLASHINFER_CHECK(blockSize <= 1024 && loadsPerThread == 1, + "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", + tokenDim, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 1024 * 8 * eltsPerThread); +#else + 1024 * eltsPerThread); +#endif + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: " + "%d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, + ceil_div(tokenDim, eltsPerThread)); + + cudaLaunchAttribute attrs[2]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; + attrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + attrs[1].id = cudaLaunchAttributeClusterDimension; + attrs[1].val.clusterDim.x = 1; + attrs[1].val.clusterDim.y = clusterSize; + attrs[1].val.clusterDim.z = 1; +#endif - cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.buffer_ptrs_dev), - reinterpret_cast(params.multicast_ptr), params.num_tokens, params.buffer_M, - params.token_dim, params.rank, - reinterpret_cast(params.buffer_flags), params.wait_for_results); + cudaLaunchConfig_t config{ + .gridDim = grid, + .blockDim = blockSize, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = attrs, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + .numAttrs = 2, +#else + .numAttrs = 1, +#endif + }; + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, RMSNORM) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &config, &oneshotAllreduceFusionKernel, output, residualOut, input, \ + residualIn, gamma, ucPtrs, mcPtr, numTokens, tokenDim, static_cast(params.epsilon), \ + params.rank, params.bufferFlags)); +#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + if (params.rmsNormFusion) { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true); \ + } else { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false); \ + } - return cudaSuccess; -} + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T* residualOut = reinterpret_cast(params.residualOut); + T const* input = reinterpret_cast(params.input); + T const* residualIn = reinterpret_cast(params.residualIn); + T const* gamma = reinterpret_cast(params.gamma); -template -cudaError_t twoshot_allreduce_dispatch_world_size(AllReduceParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_allreduce_dispatch_world_size"); - switch (params.nranks) { + switch (params.nRanks) { + // FIXME: Do we need other world sizes? case 2: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(2); case 4: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(4); case 8: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(8); case 16: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(16); case 32: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(32); case 64: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(64); default: FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } +#undef LAUNCH_ALLREDUCE_KERNEL + return cudaSuccess; } -template -__device__ void copy_f4(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4 const*)src; - __pipeline_memcpy_async(dst4, src4, sizeof(float4)); -} - -template -__device__ void copy_f4_ldg(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4*)src; - *dst4 = *src4; -} - -__device__ float4 loadfloat4(void const* ptr) { - // Check alignment - ptr should be 16-byte aligned for safe float4 load - if (reinterpret_cast(ptr) % 16 != 0) { - // Fall back to scalar loads if not aligned - float4 return_value; - float const* float_ptr = reinterpret_cast(ptr); - return_value.x = float_ptr[0]; - return_value.y = float_ptr[1]; - return_value.z = float_ptr[2]; - return_value.w = float_ptr[3]; - return return_value; - } +enum MNNVLTwoShotStage : uint8_t { + SCATTER = 0, + BROADCAST = 1, + NUM_STAGES = 2, +}; - float4 return_value; +template +__global__ __launch_bounds__(128) void twoshotAllreduceKernel( + T* outputPtr, T const* shardPtr, T** inputPtrs, T* mcastPtr, uint32_t const numTokens, + uint32_t const tokenDim, uint32_t const rank, uint32_t* bufferFlags, + bool const wait_for_results) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), - "=f"(return_value.w) - : "l"(ptr)); + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; - return return_value; -} + int destRank = token % WorldSize; + int destTokenOffset = token / WorldSize; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); -// Safer version that checks bounds before loading -template -__device__ float4 loadfloat4_safe(T const* ptr, int remaining_elements) { - float return_value[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + T* scatterBufLocal = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER)); + T* scatterBufDest = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[destRank], MNNVLTwoShotStage::SCATTER)); + T* broadcastBufW = + reinterpret_cast(flag.getCurLamportBuf(mcastPtr, MNNVLTwoShotStage::BROADCAST)); + T* broadcastBufR = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST)); - if (remaining_elements <= 0) { - return *(float4*)return_value; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Make sure the clear function is called before OOB thread exits + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; } - // Check alignment - ptr should be 16-byte aligned for safe float4 load - bool is_aligned = (reinterpret_cast(ptr) % 16 == 0); + // =============================== Scatter =============================== - if (is_aligned && remaining_elements >= 4) { - // Safe to do vectorized load - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), - "=f"(return_value[3]) - : "l"(ptr)); - } else { - // Fall back to scalar loads with bounds checking - float const* float_ptr = reinterpret_cast(ptr); - for (int i = 0; i < 4 && i < remaining_elements; i++) { - return_value[i] = toFloat(float_ptr[i]); + // Load vectorized data + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) { + val.elements[i] = fromFloat(0.F); } } - return *(float4*)return_value; -} + // Store vectorized data + reinterpret_cast( + &scatterBufDest[destTokenOffset * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = + val.packed; -template -inline __device__ T add(T a, T b) { - return a + b; -} + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER); -#define FINAL_MASK 0xffffffff + // =============================== Reduction and Broadcast =============================== -template -__inline__ __device__ T warpReduceSum(T val) { + if ((token % WorldSize) == rank) { + int localToken = token / WorldSize; + float accum[kELTS_PER_THREAD] = {0.F}; + + // Use float as we only check each float value for validity + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, - 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &scatterBufLocal[localToken * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); -inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; - val = warpReduceSum(val); - if (lane_id == 0) { - smem[warp_id] = val; - } - __syncthreads(); - val = lane_id < warp_num ? smem[lane_id] : 0.f; - val = warpReduceSum(val); - return val; -} + // Check validity across all elements +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); + } + } + if (valid) { + break; + } + } -template -__global__ void __launch_bounds__(128, 1) - RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, - T_IN const* gamma, float epsilon, T_IN const* residual, int batch_size, - uint32_t* buffer_flags) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Now we view it as the value for reduction + auto values = reinterpret_cast*>(valuesLamport); +#pragma unroll + for (int r = 0; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } - static bool const LAMPORT = true; + // Store vectorized result + PackedVec packedAccum; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } + reinterpret_cast(&broadcastBufW[token * tokenDim])[packedIdx] = packedAccum.packed; + } - extern __shared__ uint8_t smem[]; + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST); - int sample = blockIdx.y; + // Optionally wait for results if the next layer isn't doing the Lamport check + if (wait_for_results) { + // Update the atomic counter to indicate the block has read the offsets + flag.ctaArrive(); - static int const CGA_THREADS = NUM_THREADS * 1; + PackedVec valLamport; + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + while (isNegZero(valLamport.elements[0])) { + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + } + if (outputPtr) { + reinterpret_cast(&outputPtr[threadOffset])[0] = valLamport.packed; + } - static int const ITERS = DIM / CGA_THREADS; - float r_input[ITERS]; - float r_gamma[ITERS]; + // Update the buffer flags + flag.waitAndUpdate( + {static_cast(round_up(numTokens, WorldSize) * tokenDim * + kELT_SIZE), // Clear Size for scatter stage + static_cast(numTokens * tokenDim * kELT_SIZE), // Clear Size for broadcast stage + 0, 0}); + // If not wait for results, we will rely on the following kernel to update the buffer + } +} - T_IN* sh_input = (T_IN*)&smem[0]; - T_IN* sh_residual = (T_IN*)&smem[NUM_INPUTS * NUM_THREADS * ITERS * sizeof(T_IN)]; - T_IN* sh_gamma = (T_IN*)&smem[(NUM_INPUTS + 1) * NUM_THREADS * ITERS * sizeof(T_IN)]; +using utils::copyF4; +// This kernel works performant when loads_per_thread is 1. +// For this mode, we are able to support up to 1024 (threads) x 8 (elements) = 8192 hidden +// dimension. There are two options for further scaling up: +// 1. Use CGA if supported. It expands the hidden dimension to 8k x 8 = 64k. +// 2. Set loads_per_thread >1. Which can be used if CGA is not supported. Note that this will +// be limited by the shared memory size and register count. +template +__global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OUT* outputNorm, + T_IN* bufferInput, T_IN const* gamma, + float epsilon, T_IN const* residual, + uint32_t numTokens, uint32_t dim, + uint32_t worldSize, uint32_t* bufferFlags) { + static_assert(std::is_same_v, "T_IN and T_OUT must be the same type"); + static int const kELTS_PER_LOAD = sizeof(float4) / sizeof(T_IN); + + uint32_t const token = blockIdx.x; + uint32_t const blockSize = blockDim.x; + uint32_t const threadOffset = threadIdx.x; + + uint32_t numThreads = blockSize; + uint32_t clusterSize = 1; + uint32_t blockOffset = 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + numThreads = cluster.num_threads(); + clusterSize = cluster.num_blocks(); + blockOffset = cluster.block_rank(); +#endif + uint32_t const dimPadded = round_up(dim, kELTS_PER_LOAD * numThreads); + uint32_t const elemsPerThread = dimPadded / numThreads; + uint32_t const loadStride = blockSize; - static int const ELTS_PER_THREAD = sizeof(float4) / sizeof(T_IN); + extern __shared__ uint8_t smem[]; + float rInput[LoadsPerThread * kELTS_PER_LOAD]; + uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD]; - int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; + uint32_t const smemBufferSize = blockSize * elemsPerThread * sizeof(T_IN); + T_IN* smemInput = (T_IN*)&smem[0]; + T_IN* smemResidual = (T_IN*)&smem[smemBufferSize]; + T_IN* smemGamma = (T_IN*)&smem[2 * smemBufferSize]; - LamportFlags flags(buffer_flags); - T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); + T_IN* input = reinterpret_cast( + flag.getCurLamportBuf(reinterpret_cast(bufferInput), MNNVLTwoShotStage::BROADCAST)); -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif + // The offset that current thread should load from. Note that the hidden dimension is split by CGA + // size and each block loads a contiguous chunk; The size of chunk that each block processes + uint32_t const blockChunkSize = ceil_div(dim, clusterSize * kELTS_PER_LOAD) * kELTS_PER_LOAD; + uint32_t const blockLoadOffset = token * dim + blockOffset * blockChunkSize; - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - offsets[i][j] = - i * batch_size * DIM + sample * DIM + blockIdx.x * DIM / 1 + k * ELTS_PER_THREAD; - } +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + // Each block load a contiguous chunk of tokens + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + offsets[i] = blockLoadOffset + threadLoadOffset; } #pragma unroll - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_residual[i * ELTS_PER_THREAD], - &residual[sample * DIM + blockIdx.x * DIM + i * ELTS_PER_THREAD]); + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemResidual[threadLoadOffset], &residual[blockLoadOffset + threadLoadOffset]); + } } - __pipeline_commit(); - #pragma unroll - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_gamma[i * ELTS_PER_THREAD], &gamma[blockIdx.x * DIM + i * ELTS_PER_THREAD]); + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemGamma[threadLoadOffset], &gamma[blockOffset * blockChunkSize + threadLoadOffset]); + } } - __pipeline_commit(); - flags.cta_arrive(); - // Load all inputs + flag.ctaArrive(); bool valid = false; - -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if (!LAMPORT) cudaGridDependencySynchronize(); -#endif - + // ACQBLK if not lamport while (!valid) { valid = true; #pragma unroll - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - - float4* dst4 = (float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - - // Calculate the absolute element offset from the start of buffer_input - int element_offset = offsets[i][j]; - - // The input pointer is already offset to: &buffer_input[buffer_offset + buffer_size] - // So the actual pointer we're accessing is: input + element_offset - // Which equals: &buffer_input[buffer_offset + buffer_size + element_offset] - - float4* src4 = (float4*)&input[element_offset]; - - float4 value; - // Check if we have enough elements remaining for a safe float4 load - if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= flags.buffer_size) { - value = loadfloat4(src4); - } else { - // Use safe load for boundary cases or out-of-bounds - int remaining_elements = flags.buffer_size - element_offset; - if (remaining_elements <= 0) { - // Completely out of bounds, return zeros - float4 return_value = {0.0f, 0.0f, 0.0f, 0.0f}; - value = return_value; - } else { - value = loadfloat4_safe(reinterpret_cast(src4), remaining_elements); - } - } - - if (LAMPORT) { - // Assume that the 16B were written atomically, so we only need to check one value - T_IN lowest_val = *(T_IN*)&value; - valid &= !isNegZero(lowest_val); - } - *dst4 = value; - } - } - } - - __syncthreads(); - - // Perform the initial input reduction - if (NUM_INPUTS > 0) { - T_IN accum[ELTS_PER_THREAD]; - float4* accum4 = (float4*)&accum; - - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; - *accum4 = *(float4*)&sh_input[k * ELTS_PER_THREAD]; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + float4* dst4 = reinterpret_cast(&smemInput[threadLoadOffset]); + float4 const* src4 = reinterpret_cast(&input[offsets[i]]); - for (int i = 1; i < NUM_INPUTS; i++) { - float4 data = *(float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - T_IN* p_d = (T_IN*)&data; - for (int x = 0; x < ELTS_PER_THREAD; x++) { - accum[x] += p_d[x]; - } + float4 value = loadPackedVolatile(src4); + // Assume that the 16B were written atomically, so we only need to check one value + valid &= !isNegZero(value.x); + *dst4 = value; } - - // Write back to input 0's staging location. No sync needed since all data localized to - // thread. - *(float4*)&sh_input[k * ELTS_PER_THREAD] = *accum4; } } - // Wait for residual __pipeline_wait_prior(1); __syncthreads(); - float thread_sum = 0.f; - + float threadSum = 0.f; #pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 inp4 = - *(float4*)&sh_input[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - float4 res4 = - *(float4*)&sh_residual[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - - T_IN* r_inp = (T_IN*)&inp4; - T_IN* r_res = (T_IN*)&res4; + for (int i = 0; i < LoadsPerThread; i++) { + int threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec inp{.packed = loadPacked(&smemInput[threadLoadOffset])}; + PackedVec res{.packed = loadPacked(&smemResidual[threadLoadOffset])}; - float4 out4; - - T_IN* r_out = (T_IN*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - - T_IN inp_plus_resid = r_inp[ii] + r_res[ii]; - r_out[ii] = inp_plus_resid; - r_input[i] = toFloat(inp_plus_resid); + PackedVec inp_plus_res = inp + res; +#pragma unroll + for (int j = 0; j < kELTS_PER_LOAD; j++) { + rInput[i * kELTS_PER_LOAD + j] = toFloat(inp_plus_res.elements[j]); + threadSum += toFloat(inp_plus_res.elements[j] * inp_plus_res.elements[j]); + } - // Accumulate the squares for RMSNorm - thread_sum += toFloat(inp_plus_resid * inp_plus_resid); + *reinterpret_cast(&outputPreNorm[blockLoadOffset + threadLoadOffset]) = + inp_plus_res.packed; } - - *(float4*)&input_plus_residual[sample * DIM + blockIdx.x * DIM + - io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; } - // Wait for Gamma. There will be a global synchronization as part of the reduction __pipeline_wait_prior(0); - float cluster_sum = block_reduce_sum(thread_sum); - - float rcp_rms = rsqrtf(cluster_sum / DIM + epsilon); + float blockSum = blockReduceSum(threadSum); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 gamma4 = - *(float4*)&sh_gamma[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - T_IN* r_g4 = (T_IN*)&gamma4; - - float4 out4; - // FIXME: this only works if T_OUT == T_IN - T_OUT* r_out = (T_OUT*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - r_gamma[i] = toFloat(r_g4[ii]); - r_out[ii] = fromFloat(r_gamma[i] * r_input[i] * rcp_rms); + float fullSum = blockSum; + __shared__ float sharedVal[8]; + // Use CGA Reduction if supported +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; } - - *(float4*)&output_norm[sample * DIM + blockIdx.x * DIM + io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; } - // Update the buffer pointers - flags.wait_and_update(batch_size); #endif -} -template -cudaError_t twoshot_rmsnorm_dispatch(RMSNormParams& params) { - static constexpr int NUM_THREADS = 128; - static constexpr int CGA_THREADS = NUM_THREADS; - constexpr int iters = H_DIM / CGA_THREADS; + float rcpRms = rsqrtf(fullSum / dim + epsilon); - dim3 grid(1, params.batch, 1); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = NUM_THREADS; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; - - size_t shmem_size = 3 * NUM_THREADS * iters * sizeof(T); - config.dynamicSmemBytes = shmem_size; +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + PackedVec r_out; + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec gamma = {.packed = loadPacked(&smemGamma[threadLoadOffset])}; - cudaFuncSetAttribute(&RMSNorm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); +#pragma unroll + for (uint32_t j = 0; j < kELTS_PER_LOAD; j++) { + r_out.elements[j] = fromFloat(toFloat(gamma.elements[j]) * + rInput[i * kELTS_PER_LOAD + j] * rcpRms); + } - cudaLaunchKernelEx( - &config, &RMSNorm, reinterpret_cast(params.residual_output), - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.gamma), static_cast(params.epsilon), - reinterpret_cast(params.residual), params.batch, params.buffer_flags); + *reinterpret_cast(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed; + } + } + constexpr int kELTS_SIZE = sizeof(T_IN); - return cudaSuccess; + // Update the buffer pointers + flag.waitAndUpdate({static_cast(round_up(numTokens, worldSize) * dim * kELTS_SIZE), + static_cast(numTokens * dim * kELTS_SIZE), 0, 0}); } template -cudaError_t twoshot_rmsnorm_dispatch_hidden_dim(RMSNormParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_rmsnorm_dispatch_hidden_dim"); - switch (params.hidden_dim) { - case 2048: - return twoshot_rmsnorm_dispatch(params); - case 4096: - return twoshot_rmsnorm_dispatch(params); - case 5120: - return twoshot_rmsnorm_dispatch(params); // Llama-4 - case 7168: - return twoshot_rmsnorm_dispatch(params); // DeepSeek - case 8192: - return twoshot_rmsnorm_dispatch(params); +cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const numEltsPerThread = sizeof(float4) / sizeof(T); + FLASHINFER_CHECK(tokenDim % numEltsPerThread == 0, + "[MNNVL AllReduceTwoShot] token_dim must be divisible by %d", numEltsPerThread); + + int const arNumThreads = ceil_div(tokenDim, numEltsPerThread); + int const arNumBlocksPerToken = ceil_div(arNumThreads, 128); + + dim3 arGrid(numTokens, arNumBlocksPerToken); + + cudaLaunchAttribute arAttrs[1]; + arAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + arAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; + + cudaLaunchConfig_t arConfig{ + .gridDim = arGrid, + .blockDim = 128, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = arAttrs, + .numAttrs = 1, + }; + + FLASHINFER_LOG_DEBUG("[MNNVL AllReduceTwoShot] Dispatch: grid size: (%d, %d, 1), block_size: 128", + numTokens, arNumBlocksPerToken); + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &arConfig, &twoshotAllreduceKernel, output, input, ucPtrs, mcastPtr, \ + numTokens, tokenDim, params.rank, params.bufferFlags, (!params.rmsNormFusion))); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcastPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T const* input = reinterpret_cast(params.input); + switch (params.nRanks) { + case 2: + LAUNCH_ALLREDUCE_KERNEL(2); + break; + case 4: + LAUNCH_ALLREDUCE_KERNEL(4); + break; + case 8: + LAUNCH_ALLREDUCE_KERNEL(8); + break; + case 16: + LAUNCH_ALLREDUCE_KERNEL(16); + break; + case 32: + LAUNCH_ALLREDUCE_KERNEL(32); + break; + case 64: + LAUNCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL TwoShot RMSNorm: unsupported hidden_dim " + - std::to_string(params.hidden_dim) + - ". Supported sizes: {2048, 4096, 5120, 7168, 8192}"); + FLASHINFER_ERROR("[MNNVL AllReduceTwoShot] Unsupported world_size" + + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } -} +#undef LAUNCH_ALLREDUCE_KERNEL + + // Launch the rmsnorm lamport kernel if fusion is enabled + if (params.rmsNormFusion) { + auto gridConfig = adjustGridConfig(numTokens, tokenDim, numEltsPerThread); + int rnBlockSize = std::get<0>(gridConfig); + int rnClusterSize = std::get<1>(gridConfig); + int rnLoadsPerThread = std::get<2>(gridConfig); + + int rnNumThreads = rnClusterSize * rnBlockSize; + dim3 rnGrid(numTokens, rnClusterSize, 1); + cudaLaunchConfig_t rnConfig; + cudaLaunchAttribute rnAttrs[2]; + rnConfig.stream = params.stream; + rnConfig.gridDim = rnGrid; + rnConfig.blockDim = rnBlockSize; + rnConfig.attrs = rnAttrs; + rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + rnAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#ifndef DISABLE_CGA + rnAttrs[1].id = cudaLaunchAttributeClusterDimension; + rnAttrs[1].val.clusterDim.x = 1; + rnAttrs[1].val.clusterDim.y = rnClusterSize; + rnAttrs[1].val.clusterDim.z = 1; + rnConfig.numAttrs = 2; +#else + rnConfig.numAttrs = 1; +#endif + bool const rnUseCGA = rnClusterSize > 1; + int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); + int const iters = dimPadded / rnNumThreads; + + size_t const smemSize = 3 * rnBlockSize * iters * getDTypeSize(params.dType); + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " + "cluster_size: %d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, rnClusterSize, rnBlockSize, rnClusterSize, rnLoadsPerThread, + ceil_div(tokenDim, numEltsPerThread)); + +#define RUN_RMSNORM_KERNEL(LOADS_PER_THREAD) \ + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(&rmsNormLamport, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + smemSize)); \ + rnConfig.dynamicSmemBytes = smemSize; \ + FLASHINFER_CUDA_CALL( \ + cudaLaunchKernelEx(&rnConfig, &rmsNormLamport, residualOut, output, \ + bufferInput, gamma, static_cast(params.epsilon), residualIn, \ + numTokens, tokenDim, params.nRanks, params.bufferFlags)); + + T* residualOut = reinterpret_cast(params.residualOut); + T* output = reinterpret_cast(params.output); + T* bufferInput = reinterpret_cast(params.bufferPtrLocal); + T const* gamma = reinterpret_cast(params.gamma); + T const* residualIn = reinterpret_cast(params.residualIn); + if (rnUseCGA) { + RUN_RMSNORM_KERNEL(1); + } else { + switch (rnLoadsPerThread) { + case 1: + RUN_RMSNORM_KERNEL(1); + break; + case 2: + RUN_RMSNORM_KERNEL(2); + break; + case 3: + RUN_RMSNORM_KERNEL(3); + break; + case 4: + RUN_RMSNORM_KERNEL(4); + break; + case 5: + RUN_RMSNORM_KERNEL(5); + break; + case 6: + RUN_RMSNORM_KERNEL(6); + break; + case 7: + RUN_RMSNORM_KERNEL(7); + break; + case 8: + RUN_RMSNORM_KERNEL(8); + break; + default: + FLASHINFER_ERROR("[MNNVL AllReduceTwoShotRMSNorm] Unsupported loads_per_thread" + + std::to_string(rnLoadsPerThread) + + ". Supported sizes: {1, 2, 3, 4, 5, 6, 7, 8}"); + return cudaErrorInvalidValue; + } // switch (rnLoadsPerThread) + } // if (rnUseCGA) +#undef RUN_RMSNORM_KERNEL + + } // if (params.rmsNormFusion) + return cudaSuccess; +} } // namespace trtllm_mnnvl_allreduce } // namespace flashinfer diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 0471bd1081..20c19a0eae 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -335,6 +335,18 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +inline int GetCudaMultiProcessorCount() { + static int sm_count = 0; + if (sm_count == 0) { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_id); + sm_count = device_prop.multiProcessorCount; + } + return sm_count; +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); From 1230273a478dc6e194891e9345cf27ca6b2dcb58 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 18 Nov 2025 17:13:28 -0800 Subject: [PATCH 02/14] Initial python interface. Need adjustment. --- csrc/trtllm_mnnvl_allreduce.cu | 8 +- flashinfer/comm/mnnvl.py | 18 +- flashinfer/comm/trtllm_mnnvl_ar.py | 418 ++++++++++++++--------------- 3 files changed, 230 insertions(+), 214 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 05a1684aa0..ad23037ff3 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -33,7 +33,8 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, TensorView output, Optional residual_out, - Optional gamma, Optional epsilon) { + Optional residual_in, Optional gamma, + Optional epsilon) { cudaSetDevice(input.device().device_id); auto stream = get_stream(input.device()); @@ -82,9 +83,8 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt // input data params.input = const_cast(input.data_ptr()); - params.residualIn = residual_out.has_value() - ? const_cast(residual_out.value().data_ptr()) - : nullptr; + params.residualIn = + residual_in.has_value() ? const_cast(residual_in.value().data_ptr()) : nullptr; params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 2d280a68e8..a1a8c58d02 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -1005,7 +1005,7 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_mc_buffer( + def get_multicast_buffer( self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 ) -> torch.Tensor: """ @@ -1019,12 +1019,28 @@ def get_mc_buffer( Returns: A PyTorch tensor wrapping the multicast buffer section """ + + # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. + raise NotImplementedError("Not implemented yet") + + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: + """ + Returns a PyTorch tensor view of the unicast buffer portion. + """ + + # TODO: How can I warp a raw pointer to a tensor in python level? raise NotImplementedError("Not implemented yet") def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() + def get_unicast_ptr(self, rank: int) -> int: + """Get the raw unicast pointer to a given rank""" + return self.mcast_device_memory.get_unicast_ptr(rank) + def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 84a9c150de..f26a37d069 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -5,9 +5,10 @@ import functools import math -import os +import logging from types import SimpleNamespace -from typing import Optional, Tuple +from typing import Optional +from enum import Enum import torch @@ -25,238 +26,241 @@ def mpi_barrier(): MPI.COMM_WORLD.Barrier() +class MNNVLAllreduceFusionStrategy(Enum): + ONESHOT = 0 + TWOSHOT = 1 + AUTO = 99 + + @staticmethod + def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: + elem_size = torch.tensor([], dtype=dtype).element_size() + return num_tokens * hidden_dim * tp_size * elem_size <= kMNNVLOneShotThreshold + + +# Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size +kMNNVLOneShotThreshold = 64 * 1024 * 8 * 2 + + +class MNNVLAllreduceFusionWorkspace: + NUM_LAMPORT_BUFFERS = 3 + + def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None): + """ + Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. + + Args: + mapping: Mapping configuration containing rank info + buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. + """ + if buffer_size_in_bytes is None: + # Default to 16MB workspace size if not provided + buffer_size_in_bytes = 16 * (1024**2) + else: + # Round up to the nearest multiple of 8MB + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2)) * (8 * (1024**2))) + + if buffer_size_in_bytes > (2**32 - 1): + raise ValueError( + f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." + ) + + self.buffer_size_bytes = buffer_size_in_bytes + self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + self.rank = mapping.tp_rank + self.tp_size = mapping.tp_size + logging.debug( + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes." + ) + self.mcast_buffer_handle = McastGPUBuffer( + self.workspace_size_bytes, + mapping.tp_size, + mapping.tp_rank, + torch.device("cuda", mapping.local_rank), + mapping.is_multi_node(), + ) + + # We use FP32 for sentinel value regardless of the real dtype + self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) + # Wait until the initialization is done + torch.cuda.synchronize() + # FIXME: We are assuming using the COMM_WORLD. + mpi_barrier() + + # This is a buffer to maintain the state of this allreduce Op + # Should have the same lifetime with self._buffer + # The flag should be binded to each buffer allocation + # Layout: [cur idx, dirty idx, bytes per buffer, dirty num stages, numBytesToClear[4], access count ptr] + num_bytes_to_clear = [0] * 4 + self.buffer_flags = torch.tensor( + [0, 2, self.buffer_size_bytes, 0, *num_bytes_to_clear, 0], + dtype=torch.uint32, + device=torch.device("cuda", mapping.local_rank), + ) + + self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() + self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) + self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + + @staticmethod + def get_required_buffer_size_bytes( + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> int: + """ + Calculate the required buffer size for a given problem size. + """ + elem_size = torch.tensor([], dtype=dtype).element_size() + is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot + ): + # For one-shot, each rank needs to store num_tokens * tp_size tokens + buffer_size = num_tokens * hidden_dim * tp_size * elem_size + else: + # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. + # 2 Stage is required for the two-shot allreduce. + buffer_size = 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + return buffer_size + + @functools.cache def get_trtllm_mnnvl_comm_module(): module = gen_trtllm_mnnvl_comm_module().build_and_load() @register_custom_op( - "flashinfer::trtllm_mnnvl_all_reduce", + "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ "inp", "multicast_buffer_ptr", "buffer_ptrs_dev", - "buffer_mnnvl", + "buffer_ptr_local", "buffer_flags_mnnvl", "nranks", "rank", - "wait_for_results", + "rmsnorm_fusion", "launch_with_pdl", + "use_oneshot", "out", + "residual_out", + "residual_in", + "gamma", + "epsilon", ], ) - def trtllm_mnnvl_all_reduce( + def trtllm_mnnvl_allreduce_fusion( inp: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer - buffer_mnnvl: torch.Tensor, + buffer_ptr_local: int, # Pointer address as integer buffer_flags_mnnvl: torch.Tensor, nranks: int, rank: int, - wait_for_results: bool, + rmsnorm_fusion: bool, launch_with_pdl: bool, + use_oneshot: bool, out: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + epsilon: Optional[float], ) -> None: - module.trtllm_mnnvl_all_reduce( + """ + Perform a multi-node NVLink all-reduce operation with fusion. + Args: + inp: Input tensor + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer + buffer_ptr_local: Pointer to local buffer as an integer + buffer_flags_mnnvl: Buffer flags tensor for synchronization + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + rmsnorm_fusion: Whether to perform RMSNorm fusion + launch_with_pdl: Whether to launch with PDL + use_oneshot: Whether to use one-shot (true) or two-shot (false) + outp: Output tensor + residual_out: Residual output tensor (if rmsnorm) + gamma: Gamma tensor (if rmsnorm) + epsilon: Epsilon value (if rmsnorm) + """ + module.trtllm_mnnvl_allreduce_fusion( inp, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_mnnvl, + buffer_ptr_local, buffer_flags_mnnvl, nranks, rank, - wait_for_results, + rmsnorm_fusion, launch_with_pdl, + use_oneshot, out, - ) - - @register_custom_op( - "flashinfer::trtllm_mnnvl_rmsnorm", - mutates_args=[ - "mcast_buffer_input", - "prenorm_output", - "normed_output", - "gamma", - "epsilon", - "residual", - "buffer_flags", - "launch_with_pdl", - ], - ) - def trtllm_mnnvl_rmsnorm( - mcast_buffer_input: int, - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - buffer_flags: torch.Tensor, - launch_with_pdl: bool, - ) -> None: - """Performs MNNVL TwoShot RMSNorm on the communication buffer. - - Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - mcast_buffer_input: Input tensor - gamma: The gamma parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add - buffer_flags: Buffer flags for synchronization - launch_with_pdl: Whether to launch with PDL - """ - return module.trtllm_mnnvl_rmsnorm( - mcast_buffer_input, - prenorm_output, - normed_output, + residual_out, + residual_in, gamma, epsilon, - residual, - buffer_flags, - launch_with_pdl, ) return SimpleNamespace( - trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, - trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm, - ) - - -def get_allreduce_mnnvl_workspace( - mapping: Mapping, - dtype: torch.dtype, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, - buffer_size_in_bytes: Optional[int] = None, -) -> Tuple[McastGPUBuffer, torch.Tensor, int]: - """Get workspace buffers needed for multi-node NVLink all-reduce operation. - - This function allocates and initializes the workspace buffers required for performing - multi-node NVLink all-reduce operations. It creates: - 1. A multicast GPU buffer for communication between nodes - 2. A flags tensor to track buffer state - 3. Maximum number of elements that can fit in the buffer - - The buffer size is calculated to efficiently handle common hidden dimensions - (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. - - Args: - mapping: Tensor parallel mapping configuration containing rank info - dtype: Data type of the tensors being reduced - comm: Optional communication backend for multi-node synchronization - buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens - - Returns: - Tuple containing: - - McastGPUBuffer: Multicast buffer for inter-node communication - - torch.Tensor: Buffer flags tensor tracking state - - int: Maximum number of elements that can fit in buffer - """ - force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - stride = 3 * 2 * dtype.itemsize - # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 - # max_num_elements must be a multiple of 286720 - lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = ( - buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 - ) - buffer_size_in_bytes = math.ceil( - TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) - ) * (lcm_hidden_dim * stride) - max_num_elements = buffer_size_in_bytes // stride - - mcast_buffer = McastGPUBuffer( - buffer_size_in_bytes, - mapping.tp_size, - mapping.tp_rank, - torch.device("cuda", mapping.local_rank), - mapping.is_multi_node() or force_mn, - comm_backend_for_handle_transfer=comm_backend_for_handle_transfer, - ) - - # Initialize the unicast buffer with -0.0 - mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) - - # CPU barrier since we assume this should not be called in cuda graph - torch.cuda.synchronize() - if comm_backend_for_handle_transfer is None: - mpi_barrier() - else: - comm_backend_for_handle_transfer.barrier() - - # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] - buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", mapping.local_rank), - ) - - return ( - mcast_buffer, - buffer_flags, - max_num_elements, + trtllm_mnnvl_allreduce_fusion=trtllm_mnnvl_allreduce_fusion, ) def trtllm_mnnvl_all_reduce( inp: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, - nranks: int, - rank: int, - wait_for_results: bool, + workspace: MNNVLAllreduceFusionWorkspace, launch_with_pdl: bool, out: Optional[torch.Tensor] = None, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> None: """Perform a multi-node NVLink all-reduce operation across multiple GPUs. This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) technology to efficiently combine tensors across multiple GPUs and nodes. - There are 3 steps: - 1. scatter each GPU's input shard to the right unicast buffer - 2. perform all-reduce on each GPU - 3. broadcast the result to all GPUs + There are 2 variants: One-shot and Two-shot: + - One-shot: Each rank stores local shard to all other ranks. Each ranks will receive all shards at the end of the communication round and perfom local reduction. Suitable for small data size and is optimized for low latency. + - Two-shot: There will be 3 steps: + 1. Scatter each GPU's input shard to other ranks. Each rank will received all shards of a slice of tokens. + 2. Each rank perform reduction on the local tokens. + 3. Each rank broadcast the result to all ranks. + Suitable for large data size and is optimized for balancing throughput and latency. Args: - inp: Local Input Shard - multicast_buffer_ptr: Pointer to the multicast buffer as an integer - buffer_ptrs_dev: Pointer to device buffer pointers as an integer - buffer_M: Maximum number of elements // hidden_dim - buffer_flags_mnnvl: Tensor containing buffer state flags - nranks: Total number of ranks participating in the all-reduce - rank: Current process rank - wait_for_results: If True, store the result to out - launch_with_pdl: If True, launch using Programmatic Dependent Launch - [Optional] out: Output tensor to store the result (required if wait_for_results is True) - + inp: Local Input Shard [num_tokens, hidden_dim] + workspace: MNNVLAllreduceFusionWorkspace + launch_with_pdl: Whether to launch with PDL + out: Output tensor to store the result + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. """ if len(inp.shape) != 2: - raise ValueError( - f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." - ) - - if inp.shape[0] > buffer_M: - raise ValueError( - f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." - ) + raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") module = get_trtllm_mnnvl_comm_module() - module.trtllm_mnnvl_all_reduce( + + use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO + and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, inp.shape[0], inp.shape[1], inp.dtype) + ) + module.trtllm_mnnvl_allreduce_fusion( inp, - multicast_buffer_ptr, - int(buffer_ptrs_dev), - buffer_M, - buffer_flags_mnnvl, - nranks, - rank, - wait_for_results, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + False, # No RMSNorm Fusion launch_with_pdl, + use_oneshot, out, + None, + None, + None, + None, ) @@ -264,19 +268,14 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: torch.Tensor, normed_output: torch.Tensor, shard_input: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - unicast_ptr: int, # Local unicast buffer pointer - buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, - nranks: int, - rank: int, + workspace: MNNVLAllreduceFusionWorkspace, gamma: torch.Tensor, epsilon: float, residual: torch.Tensor, launch_with_pdl: bool, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> None: - """Performs MNNVL TwoShot Allreduce + RMSNorm. + """Performs MNNVL Allreduce + RMSNorm. This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. @@ -286,43 +285,44 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: Output tensor for prenorm results normed_output: Output tensor for normalized results shard_input: Input tensor shard - multicast_buffer_ptr: Pointer address as integer for multicast buffer - buffer_ptrs_dev: Pointer address as integer for device buffer pointers - unicast_ptr: Pointer address as integer for unicast buffer - buffer_M: Maximum number of elements // hidden_dim - buffer_flags_mnnvl: Buffer flags for synchronization - nranks: Number of ranks in the tensor parallel group - rank: Current rank in the tensor parallel group + workspace: MNNVLAllreduceFusionWorkspace gamma: The gamma (norm weight) parameter for RMSNorm epsilon: The epsilon parameter for RMSNorm residual: The residual tensor to add launch_with_pdl: Whether to launch with PDL """ - # allreduce_result = Σ(shard_input across all ranks) - trtllm_mnnvl_all_reduce( - shard_input, - multicast_buffer_ptr, - buffer_ptrs_dev, - buffer_M, - buffer_flags_mnnvl, - nranks, - rank, - False, # No need to wait to write AR results here as we are not writing them - launch_with_pdl, - None, # out parameter - None since wait_for_results=False + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO + and MNNVLAllreduceFusionStrategy.is_one_shot( + workspace.tp_size, + shard_input.shape[0], + shard_input.shape[1], + shard_input.dtype, + ) ) - # prenorm_output = AllReduce(shard_input) + residual - # rms = sqrt(mean(prenorm_output²) + epsilon) - # normed_output = (prenorm_output / rms) * gamma - get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm( - unicast_ptr, - prenorm_output, + module.trtllm_mnnvl_allreduce_fusion( + shard_input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + True, # RMSNorm Fusion + launch_with_pdl, + use_oneshot, normed_output, + prenorm_output, + residual, gamma, epsilon, - residual, - buffer_flags_mnnvl, - launch_with_pdl, ) From 874c228f4ae74f20932dcd4e7751307c163bc57d Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 18 Nov 2025 17:40:17 -0800 Subject: [PATCH 03/14] Refactor the interface. --- flashinfer/comm/trtllm_mnnvl_ar.py | 126 ++++++++++++++++++----------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index f26a37d069..839e03411c 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -7,7 +7,7 @@ import math import logging from types import SimpleNamespace -from typing import Optional +from typing import Optional, Tuple from enum import Enum import torch @@ -34,11 +34,12 @@ class MNNVLAllreduceFusionStrategy(Enum): @staticmethod def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: elem_size = torch.tensor([], dtype=dtype).element_size() - return num_tokens * hidden_dim * tp_size * elem_size <= kMNNVLOneShotThreshold + return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size -kMNNVLOneShotThreshold = 64 * 1024 * 8 * 2 +# TODO(Refactor): Consider moving this to a configuration class or file +MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 class MNNVLAllreduceFusionWorkspace: @@ -133,7 +134,7 @@ def get_trtllm_mnnvl_comm_module(): @register_custom_op( "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ - "inp", + "input", "multicast_buffer_ptr", "buffer_ptrs_dev", "buffer_ptr_local", @@ -143,7 +144,7 @@ def get_trtllm_mnnvl_comm_module(): "rmsnorm_fusion", "launch_with_pdl", "use_oneshot", - "out", + "output", "residual_out", "residual_in", "gamma", @@ -151,7 +152,7 @@ def get_trtllm_mnnvl_comm_module(): ], ) def trtllm_mnnvl_allreduce_fusion( - inp: torch.Tensor, + input: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer buffer_ptr_local: int, # Pointer address as integer @@ -161,7 +162,7 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion: bool, launch_with_pdl: bool, use_oneshot: bool, - out: Optional[torch.Tensor], + output: torch.Tensor, residual_out: Optional[torch.Tensor], residual_in: Optional[torch.Tensor], gamma: Optional[torch.Tensor], @@ -170,7 +171,7 @@ def trtllm_mnnvl_allreduce_fusion( """ Perform a multi-node NVLink all-reduce operation with fusion. Args: - inp: Input tensor + input: Input tensor multicast_buffer_ptr: Pointer to the multicast buffer as an integer buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer buffer_ptr_local: Pointer to local buffer as an integer @@ -180,13 +181,13 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion: Whether to perform RMSNorm fusion launch_with_pdl: Whether to launch with PDL use_oneshot: Whether to use one-shot (true) or two-shot (false) - outp: Output tensor + output: Output tensor residual_out: Residual output tensor (if rmsnorm) gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ module.trtllm_mnnvl_allreduce_fusion( - inp, + input, multicast_buffer_ptr, buffer_ptrs_dev, buffer_ptr_local, @@ -196,7 +197,7 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion, launch_with_pdl, use_oneshot, - out, + output, residual_out, residual_in, gamma, @@ -208,13 +209,13 @@ def trtllm_mnnvl_allreduce_fusion( ) -def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, +def trtllm_mnnvl_allreduce( + input: torch.Tensor, workspace: MNNVLAllreduceFusionWorkspace, launch_with_pdl: bool, - out: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, -) -> None: +) -> torch.Tensor: """Perform a multi-node NVLink all-reduce operation across multiple GPUs. This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) @@ -229,24 +230,32 @@ def trtllm_mnnvl_all_reduce( Suitable for large data size and is optimized for balancing throughput and latency. Args: - inp: Local Input Shard [num_tokens, hidden_dim] + input: Local Input Shard [num_tokens, hidden_dim] workspace: MNNVLAllreduceFusionWorkspace launch_with_pdl: Whether to launch with PDL - out: Output tensor to store the result + output: Output tensor to store the result, empty tensor will be created if not provided. strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Reduced tensor [num_tokens, hidden_dim] """ - if len(inp.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") + # Check ndims here as the shape check is done in the kernel launch code. + if len(input.shape) != 2: + raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") module = get_trtllm_mnnvl_comm_module() use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, inp.shape[0], inp.shape[1], inp.dtype) + and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) ) module.trtllm_mnnvl_allreduce_fusion( - inp, + input, workspace.mc_ptr, workspace.uc_ptrs_dev, workspace.uc_ptr_local, @@ -256,7 +265,7 @@ def trtllm_mnnvl_all_reduce( False, # No RMSNorm Fusion launch_with_pdl, use_oneshot, - out, + output, None, None, None, @@ -265,36 +274,60 @@ def trtllm_mnnvl_all_reduce( def trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - shard_input: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + input: torch.Tensor, + residual_in: torch.Tensor, gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - launch_with_pdl: bool, + workspace: MNNVLAllreduceFusionWorkspace, + epsilon: Optional[float] = None, + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: """Performs MNNVL Allreduce + RMSNorm. - This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. Note: multicast buffer is the same as the unicast buffer for the current rank. Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - shard_input: Input tensor shard + input: Input tensor [num_tokens, hidden_dim] + residual_in: Residual input tensor [num_tokens, hidden_dim] + gamma: Gamma tensor [hidden_dim] workspace: MNNVLAllreduceFusionWorkspace - gamma: The gamma (norm weight) parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add + epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. + output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. + residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. launch_with_pdl: Whether to launch with PDL + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Normalized tensor [num_tokens, hidden_dim] + residual_out: Residual output tensor [num_tokens, hidden_dim] """ - if len(shard_input.shape) != 2: + + if epsilon is None: + epsilon = torch.finfo(input.dtype).eps + + if len(input.shape) != 2: + raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + if len(residual_in.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." + ) + if gamma.numel() != input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements." + ) + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + if residual_out is None: + residual_out = torch.empty_like(residual_in) + elif len(residual_out.shape) != 2: raise ValueError( - f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}." ) module = get_trtllm_mnnvl_comm_module() @@ -303,14 +336,14 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( strategy == MNNVLAllreduceFusionStrategy.AUTO and MNNVLAllreduceFusionStrategy.is_one_shot( workspace.tp_size, - shard_input.shape[0], - shard_input.shape[1], - shard_input.dtype, + input.shape[0], + input.shape[1], + input.dtype, ) ) module.trtllm_mnnvl_allreduce_fusion( - shard_input, + input, workspace.mc_ptr, workspace.uc_ptrs_dev, workspace.uc_ptr_local, @@ -320,9 +353,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( True, # RMSNorm Fusion launch_with_pdl, use_oneshot, - normed_output, - prenorm_output, - residual, + output, + residual_out, + residual_in, gamma, epsilon, ) + return output, residual_out From 17a129207dec879eb1236d4195c641430285cbed Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 16:40:48 -0800 Subject: [PATCH 04/14] Staging changes, result is wrong. --- csrc/trtllm_mnnvl_allreduce.cu | 10 +- flashinfer/comm/mnnvl.py | 415 ++++++++++-------- flashinfer/comm/trtllm_mnnvl_ar.py | 9 +- flashinfer/jit/comm.py | 1 + .../comm/trtllm_mnnvl_allreduce.cuh | 15 +- tests/comm/test_trtllm_mnnvl_allreduce.py | 265 +++++------ 6 files changed, 363 insertions(+), 352 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index ad23037ff3..c7215a4241 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -58,12 +58,14 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; if (rmsnorm_fusion) { - TVM_FFI_ICHECK(residual_out.size(0) == num_tokens && residual_out.size(1) == token_dim) + TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && + residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) - << ") but got (" << residual_out.size(0) << ", " << residual_out.size(1) << ")"; - TVM_FFI_ICHECK(gamma.size(0) == token_dim) + << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) + << ")"; + TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" - << gamma.size(0) << ")"; + << gamma.value().size(0) << ")"; } // Create the parameters struct diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index a1a8c58d02..520f6e4880 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -16,6 +16,12 @@ import ctypes import logging import os +import socket +import array +import random + +import contextlib + from abc import ABC, abstractmethod from dataclasses import dataclass import platform @@ -35,8 +41,7 @@ from cuda import cuda except ImportError as e: raise ImportError( - "Could not import the 'cuda' module. " - "Please install cuda-python that matches your CUDA version." + "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." ) from e from ..cuda_utils import checkCudaErrors @@ -57,9 +62,7 @@ def round_up(val: int, gran: int) -> int: return (val + gran - 1) & ~(gran - 1) -def create_tensor_from_cuda_memory( - ptr: int, shape: tuple, dtype: torch.dtype, device_id: int -) -> torch.Tensor: +def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, device_id: int) -> torch.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. @@ -81,9 +84,7 @@ def create_tensor_from_cuda_memory( element_size = torch.tensor([], dtype=dtype).element_size() # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) - capsule_wrapper = create_dlpack_capsule( - ptr, element_size, element_size, numel, dtype, device_id - ) + capsule_wrapper = create_dlpack_capsule(ptr, element_size, element_size, numel, dtype, device_id) # Convert to tensor and reshape tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) @@ -123,24 +124,25 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: return False -def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: """ A helper function that allocates memory on cuda and copies the data from the host to the device. """ if not host_ptr_array: return None + for addr in host_ptr_array: + print(f"DEBUG: ptr_array: 0x{addr:x}") + ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) - checkCudaErrors( - cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) - ) + checkCudaErrors(cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)) # c_array should be freed by GC - return device_ptr + return int(device_ptr) class CommBackend(ABC): @@ -155,6 +157,9 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def bcast(self, data: Any, root: int) -> Any: ... + @abstractmethod def barrier(self) -> None: ... @@ -212,6 +217,9 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def bcast(self, data: Any, root: int) -> Any: + return self._mpicomm.bcast(data, root) + def barrier(self): self._mpicomm.Barrier() @@ -287,18 +295,14 @@ def initialize(): @staticmethod def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None): MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] - comm = config.comm_backend.Split( - mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank - ) + comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) MnnvlMemory.comm = comm # type: ignore[assignment] @staticmethod def get_comm(mapping: Mapping): if MnnvlMemory.comm is not None: return MnnvlMemory.comm - comm = MpiComm().Split( - mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank - ) + comm = MpiComm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) MnnvlMemory.comm = comm return comm @@ -314,9 +318,7 @@ def get_allocation_prop(dev_id: int): arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch if is_on_aarch64: - allocation_prop.requestedHandleTypes = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ) + allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC else: allocation_prop.requestedHandleTypes = ( cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR @@ -332,27 +334,19 @@ def get_allocation_granularity(dev_id: int): option = cuda.CUmemAllocationGranularity_flags( cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ) - granularity = checkCudaErrors( - cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option) - ) + granularity = checkCudaErrors(cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)) MnnvlMemory.allocation_granularity = granularity return MnnvlMemory.allocation_granularity @staticmethod def new_mnnvl_memory_address(mapping: Mapping, size: int): - page_count = ( - size + MnnvlMemory.fabric_page_size - 1 - ) // MnnvlMemory.fabric_page_size + page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size current_rank_stride = page_count * MnnvlMemory.fabric_page_size - logging.info( - f"[MnnvlMemory] creating address with stride={current_rank_stride}" - ) + logging.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}") comm = MnnvlMemory.get_comm(mapping) comm_size = comm.Get_size() address_size = current_rank_stride * comm_size - ptr = checkCudaErrors( - cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0) - ) + ptr = checkCudaErrors(cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)) MnnvlMemory.current_start_address = int(ptr) MnnvlMemory.current_rank_stride = current_rank_stride MnnvlMemory.current_mem_offset = 0 @@ -363,44 +357,29 @@ def open_mnnvl_memory(mapping: Mapping, size: int): dev_id = int(dev) if MnnvlMemory.dev_id is None: MnnvlMemory.dev_id = dev_id - assert dev_id == MnnvlMemory.dev_id, ( - f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" - ) + assert ( + dev_id == MnnvlMemory.dev_id + ), f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" comm = MnnvlMemory.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) assert len(all_rank_allocate_sizes) == comm_size - assert all(x == size for x in all_rank_allocate_sizes), ( - "Not all rank allocating same size." - ) + assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size." granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if ( - MnnvlMemory.current_mem_offset + aligned_size - > MnnvlMemory.current_rank_stride - ): + if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride: MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) - assert ( - MnnvlMemory.current_mem_offset + aligned_size - <= MnnvlMemory.current_rank_stride - ) + assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) - allocated_mem_handle = checkCudaErrors( - cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) - ) + allocated_mem_handle = checkCudaErrors(cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)) exported_fabric_handle = checkCudaErrors( - cuda.cuMemExportToShareableHandle( - allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 - ) + cuda.cuMemExportToShareableHandle(allocated_mem_handle, allocation_prop.requestedHandleTypes, 0) ) - if ( - allocation_prop.requestedHandleTypes - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): + if allocation_prop.requestedHandleTypes == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: all_handles_data = comm.allgather(exported_fabric_handle.data) else: all_handles_data = comm.allgather(exported_fabric_handle) @@ -414,9 +393,7 @@ def open_mnnvl_memory(mapping: Mapping, size: int): pidfd = syscall(SYS_pidfd_open, pid, 0) if pidfd < 0: err = ctypes.get_errno() - raise RuntimeError( - f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" - ) + raise RuntimeError(f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}") pidfds.append(pidfd) remote_fds = [] @@ -431,9 +408,7 @@ def open_mnnvl_memory(mapping: Mapping, size: int): "to your docker run command." ) else: - error_msg += ( - " This may be due to kernel version (requires Linux 5.6+)." - ) + error_msg += " This may be due to kernel version (requires Linux 5.6+)." raise RuntimeError(error_msg) remote_fds.append(remote_fd) @@ -449,27 +424,19 @@ def open_mnnvl_memory(mapping: Mapping, size: int): for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( - MnnvlMemory.current_start_address - + MnnvlMemory.current_rank_stride * i - + MnnvlMemory.current_mem_offset + MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset ) if i == comm_rank: # Local memory mapping mem_handles[i] = allocated_mem_handle - checkCudaErrors( - cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0) - ) + checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0)) else: # Fabric memory mapping imported_mem_handle = checkCudaErrors( - cuda.cuMemImportFromShareableHandle( - remote_handle_data, allocation_prop.requestedHandleTypes - ) + cuda.cuMemImportFromShareableHandle(remote_handle_data, allocation_prop.requestedHandleTypes) ) mem_handles[i] = imported_mem_handle - checkCudaErrors( - cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) - ) + checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0)) checkCudaErrors(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)) @@ -526,20 +493,14 @@ def support_nvlink(need_all_up: bool = True): available_links = 0 for link_idx in range(link_count): try: - if pynvml.nvmlDeviceGetNvLinkCapability( - handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED - ): + if pynvml.nvmlDeviceGetNvLinkCapability(handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED): available_links += 1 is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx) if is_active: active_links += 1 except pynvml.NVMLError_NotSupported: continue - return ( - active_links == available_links and available_links > 0 - if need_all_up - else available_links > 0 - ) + return active_links == available_links and available_links > 0 if need_all_up else available_links > 0 @staticmethod def supports_mnnvl() -> bool: @@ -551,6 +512,103 @@ def supports_mnnvl() -> bool: return support_nvlink_and_all_up +# The helper class for passing the FD handle over the socket. +class IpcSocket: + """Unix Domain Socket for IPC file descriptor passing""" + + def __init__(self, rank: int, op_id: int, use_abstract=True): + """ + Initialize IPC socket + + Args: + rank: Process rank + op_id: Unique operation ID (hash) + use_abstract: Use Linux abstract socket namespace + """ + self.rank = rank + self.op_id = op_id + self.use_abstract = use_abstract + + # Create Unix domain socket (DGRAM for compatibility with C code) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + + # Create unique socket name + socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}" + + if use_abstract: + # Linux abstract socket: prepend null byte + self.socket_path = "\0" + socket_name + else: + self.socket_path = socket_name + # Remove existing socket file if it exists + with contextlib.suppress(FileNotFoundError): + os.unlink(socket_name) + + # Bind socket + self.sock.bind(self.socket_path) + + def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None): + """ + Send a file descriptor to another process + + Args: + fd: File descriptor to send + dest_rank: Destination process rank + dest_op_id: Destination operation ID + """ + # Construct destination socket path + dest_op_id = dest_op_id or self.op_id + dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}" + + if self.use_abstract: + dest_path = "\0" + dest_socket_name + else: + dest_path = dest_socket_name + + # Prepare message with file descriptor + # Send dummy byte as data (required) + dummy_data = b"\x00" + + # Pack file descriptor in ancillary data (SCM_RIGHTS) + fds = array.array("i", [fd]) + ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())] + + # Send message with file descriptor + self.sock.sendmsg([dummy_data], ancillary, 0, dest_path) + + def recv_fd(self): + """ + Receive a file descriptor from another process + + Returns: + int: Received file descriptor + """ + # Receive message with ancillary data + # Maximum size for ancillary data containing one fd + fds = array.array("i") + msg, ancdata, flags, addr = self.sock.recvmsg( + 1, + socket.CMSG_SPACE(fds.itemsize), # Buffer size for dummy data # Ancillary data size + ) + + # Extract file descriptor from ancillary data + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds = array.array("i") + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + return fds[0] + + raise RuntimeError("No file descriptor received") + + def close(self): + """Close the socket""" + self.sock.close() + if not self.use_abstract and self.socket_path: + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + + +# TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -562,6 +620,7 @@ def __init__( device_idx: int, is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -588,6 +647,7 @@ def __init__( self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 + self.comm_backend = comm_backend_for_handle_transfer or MPIBackend() # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr @@ -596,9 +656,9 @@ def __init__( self.signal_pads_dev = 0 # std::vector mSignalPadsDev self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle - self.uc_handles: List[ - int - ] = [] # std::vector mUcHandles + self.uc_handles: List[int] = [] # std::vector mUcHandles + + self._shareable_handle_type = None # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 @@ -612,9 +672,7 @@ def __init__( ) ) if multicast_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support multicasting." - ) + raise RuntimeError("[McastDeviceMemory] Device does not support multicasting.") # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) @@ -634,23 +692,21 @@ def __init__( ) ) if fabric_handle_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support fabric handle." - ) - - self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer) + raise RuntimeError("[McastDeviceMemory] Device does not support fabric handle.") + # Use fabric handle for multi-node NVLS + self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC else: - # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem - raise NotImplementedError("Single-node NVLS allocation not implemented yet") + self._init_ipc_socket() + # Use NVLink handle for single-node NVLS + self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads self.signal_pads = [0] * self.group_size for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: - checkCudaErrors( - cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) - ) + checkCudaErrors(cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)) # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) @@ -693,29 +749,19 @@ def __del__(self): checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: - checkCudaErrors( - cuda.cuMemUnmap( - self.uc_ptrs[rank], self.allocation_size - ) - ) + checkCudaErrors(cuda.cuMemUnmap(self.uc_ptrs[rank], self.allocation_size)) except Exception as e: - print( - f"Destructor: Failed to release UC handle for rank {rank}: {e}" - ) + print(f"Destructor: Failed to release UC handle for rank {rank}: {e}") # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: - checkCudaErrors( - cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) - ) + checkCudaErrors(cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size)) # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) - checkCudaErrors( - cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) - ) + checkCudaErrors(cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size)) checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") @@ -760,9 +806,16 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem( - self, buf_size: int, comm_backend_for_handle_transfer: Any = None - ): + def _init_ipc_socket(self): + if self.group_rank == 0: + # Gnerate the opId + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm_backend.bcast(opId, root=0) + self._ipc_socket = IpcSocket(self.group_rank, opId) + + def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context @@ -770,25 +823,16 @@ def _alloc_mn_mcast_mem( current_device = checkCudaErrors(cuda.cuCtxGetDevice()) if int(current_device) != self.device_idx: - print( - f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" - ) + print(f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}") except Exception as e: print(f"Error checking CUDA context: {e}") - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - # Set up allocation properties - handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + # Set up allocation properties allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = handle_type + allocation_prop.requestedHandleTypes = self._shareable_handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() - allocation_prop.location.type = ( - cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - ) + allocation_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE allocation_prop.location.id = self.device_idx allocation_prop.allocFlags.gpuDirectRDMACapable = 1 @@ -802,15 +846,13 @@ def _alloc_mn_mcast_mem( ) # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); - self.allocation_size = round_up( - buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity - ) + self.allocation_size = round_up(buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity) # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = handle_type + mc_prop.handleTypes = self._shareable_handle_type # Get multicast granularity mc_granularity = checkCudaErrors( @@ -826,30 +868,43 @@ def _alloc_mn_mcast_mem( self.uc_handles = [0] * self.group_size # Allocate local GPU memory - self.uc_handles[self.group_rank] = checkCudaErrors( - cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) - ) + self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) # Export local handle to fabric handle - my_fabric_handle = checkCudaErrors( + local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) - # All-gather fabric handles - all_fabric_handles = comm.allgather(my_fabric_handle.data) + if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + # All-gather fabric handles + all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) + else: + # Implement the allgather logic with ipc socket + # TODO: Do we need to model ipc socket as a comm backend? My tenative answer is no as it is not able to perform bootstrap without other communicator's help. + all_shareable_uc_handles = [None] * self.group_size + for i in range(self.group_size): + self.comm_backend.barrier() + # Send to peer at offset i + dest_rank = (self.group_rank + i) % self.group_size + self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) + # Receive from peer at offset -i + src_rank = (self.group_rank + self.group_size - i) % self.group_size + all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() cuda.cuCtxSynchronize() + print(f"[Rank {self.group_rank}] all_shareable_uc_handles: {all_shareable_uc_handles}") + # Import remote handles for p in range(self.group_size): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - all_fabric_handles[p], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + all_shareable_uc_handles[p], + self._shareable_handle_type, ) ) @@ -858,29 +913,43 @@ def _alloc_mn_mcast_mem( # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - # Export multicast handle - mc_fabric_handle = checkCudaErrors( + # Export multicast handle, there's only one handle for the entire group + shareable_mc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) else: - mc_fabric_handle = None - - # Broadcast multicast handle - mc_fabric_handle_data = comm.bcast( - mc_fabric_handle.data if mc_fabric_handle else None, root=0 - ) + shareable_mc_handle = None + if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + # Broadcast multicast handle + shareable_mc_handle = self.comm_backend.bcast( + shareable_mc_handle.data if shareable_mc_handle else None, root=0 + ) + else: + # Implement bcast logic with ipc socket + if self.group_rank == 0: + for p in range(1, self.group_size): + self.comm_backend.barrier() + self._ipc_socket.send_fd(shareable_mc_handle, p) + else: + # Other ranks receive from rank 0 + # We need to order the receive to avoid a race condition bug we encountered. If driver fixed this issue, the additional barriers used for ordering can be removed. + for _ in range(self.group_rank): + self.comm_backend.barrier() + shareable_mc_handle = self._ipc_socket.recv_fd() + for _ in range(self.group_size - self.group_rank - 1): + self.comm_backend.barrier() # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - mc_fabric_handle_data, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + shareable_mc_handle, + self._shareable_handle_type, ) ) @@ -893,9 +962,7 @@ def _alloc_mn_mcast_mem( # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) - ) + uc_base_ptr = checkCudaErrors(cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0)) self.uc_base_ptr = uc_base_ptr # Store for cleanup # Set up memory access descriptor @@ -909,27 +976,15 @@ def _alloc_mn_mcast_mem( for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors( - cuda.cuMemMap( - self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 - ) - ) + checkCudaErrors(cuda.cuMemMap(self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0)) # Set memory access permissions - checkCudaErrors( - cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) - ) + checkCudaErrors(cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1)) # Bind MC pointer - self.mc_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) - ) - checkCudaErrors( - cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) - ) - checkCudaErrors( - cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) - ) + self.mc_ptr = checkCudaErrors(cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0)) + checkCudaErrors(cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0)) + checkCudaErrors(cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1)) # Bind memory to multicast checkCudaErrors( @@ -958,9 +1013,7 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): # Calculate number of elements that fit in allocation_size num_elements = self.allocation_size // dsize - checkCudaErrors( - memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) - ) + checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) class McastGPUBuffer: @@ -1005,9 +1058,7 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_multicast_buffer( - self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 - ) -> torch.Tensor: + def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. @@ -1023,9 +1074,7 @@ def get_multicast_buffer( # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. raise NotImplementedError("Not implemented yet") - def get_unicast_buffer( - self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 - ) -> torch.Tensor: + def get_unicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: """ Returns a PyTorch tensor view of the unicast buffer portion. """ diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 839e03411c..0b5db72628 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -58,7 +58,7 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB - buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2)) * (8 * (1024**2))) + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) if buffer_size_in_bytes > (2**32 - 1): raise ValueError( @@ -186,6 +186,9 @@ def trtllm_mnnvl_allreduce_fusion( gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ + print( + f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}" + ) module.trtllm_mnnvl_allreduce_fusion( input, multicast_buffer_ptr, @@ -342,6 +345,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( ) ) + print( + f"[Rank {workspace.rank}] workspace.mc_ptr: {workspace.mc_ptr}, workspace.uc_ptrs_dev: {workspace.uc_ptrs_dev}, workspace.uc_ptr_local: {workspace.uc_ptr_local}" + ) + module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 27661b1fe2..4f59c8930e 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -36,6 +36,7 @@ def gen_trtllm_mnnvl_comm_module() -> JitSpec: [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", ], + extra_cuda_cflags=["-lineinfo"], ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 9198df8775..2177cfc618 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -536,6 +536,7 @@ __global__ void __launch_bounds__(1024) T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.ctaArrive(); flag.clearDirtyLamportBuf(inputPtrs[rank], -1); return; } @@ -545,7 +546,7 @@ __global__ void __launch_bounds__(1024) val.packed = loadPacked(&shardPtr[threadOffset]); #pragma unroll for (int i = 0; i < kELTS_PER_THREAD; i++) { - if (isNegZero(val.elements[i])) val.elements[i] = toFloat(0.f); + if (isNegZero(val.elements[i])) val.elements[i] = fromFloat(0.f); } reinterpret_cast( @@ -641,7 +642,7 @@ __global__ void __launch_bounds__(1024) #pragma unroll for (int i = 0; i < kELTS_PER_THREAD; i++) { packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * - fromFloat(gamma.elements[i])); + toFloat(gamma.elements[i])); } } reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; @@ -725,18 +726,24 @@ cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) // FIXME: Do we need other world sizes? case 2: DISPATCH_ALLREDUCE_KERNEL(2); + break; case 4: DISPATCH_ALLREDUCE_KERNEL(4); + break; case 8: DISPATCH_ALLREDUCE_KERNEL(8); + break; case 16: DISPATCH_ALLREDUCE_KERNEL(16); + break; case 32: DISPATCH_ALLREDUCE_KERNEL(32); + break; case 64: DISPATCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + + FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } @@ -1145,7 +1152,7 @@ cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); int const iters = dimPadded / rnNumThreads; - size_t const smemSize = 3 * rnBlockSize * iters * getDTypeSize(params.dType); + size_t const smemSize = 3 * rnBlockSize * iters * sizeof(T); FLASHINFER_LOG_DEBUG( "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e7274c46f0..e0758c271c 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -5,6 +5,10 @@ import torch from mpi4py import MPI # Added MPI import +from flashinfer.utils import set_log_level + +set_log_level("debug") + import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import CommBackend, MpiComm @@ -24,19 +28,13 @@ def row_linear_residual_norm_fusion_forward( mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], - multicast_ptr: int, - buffer_ptrs_dev: int, - unicast_ptr: int, - max_num_elements_mnnvl: int, - buffer_flags_mnnvl: torch.Tensor, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, + workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): x = x.cuda() residual = residual.cuda() norm_weight = norm_weight.cuda() reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank if comm_backend_for_handle_transfer is None: comm = MpiComm() @@ -50,75 +48,40 @@ def func( norm_weight, eps, enable_fusion, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, + workspace, ): # For both fused and unfused cases: shape = input.shape - - assert max_num_elements_mnnvl % hidden_size == 0 - input = input.view(-1, shape[-1]) - - buffer_M = max_num_elements_mnnvl // hidden_size + use_pdl = True if enable_fusion: - use_pdl = True - - prenorm_output = torch.empty_like(residual) - normed_output = torch.empty_like(residual) - trtllm_mnnvl_ar.mpi_barrier() - trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output, - normed_output, + output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( input, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - buffer_M, - buffer_flags_mnnvl, - tensor_parallel_size, - tensor_parallel_rank, + residual, norm_weight, + workspace, eps, - residual, - use_pdl, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, ) - return normed_output.view(shape), prenorm_output.view(shape) + return output.view(shape), residual_out.view(shape) else: output = torch.empty_like(input) - trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( + output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce( input, - multicast_ptr, - buffer_ptrs_dev, - buffer_M, - buffer_flags_mnnvl, - tensor_parallel_size, - tensor_parallel_rank, - True, # wait_for_results - False, # launch_with_pdl - output, # Need to provide output tensor since we are writing them out. + workspace, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, ) return (output.view(shape),) - output = func( - x.clone(), - residual.clone(), - norm_weight, - eps, - fusion, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - ) + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) assert output[0].shape == reference_output[0].shape @@ -173,7 +136,8 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + if monkeypatch is not None: + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. # Get MPI info rank = MPI.COMM_WORLD.Get_rank() @@ -198,43 +162,32 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: - print( - f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" - ) - print( - f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" - ) + print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") + print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") tensor_parallel_size = world_size eps = 1e-5 - torch.manual_seed(42) + torch.manual_seed(42 + rank) # Track if this rank failed rank_failed = False failure_message = "" try: - # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list - # This workspace is sized for the maximum expected sequence length and can be reused within each list - # Each parameterized list gets its own fresh workspace allocation - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes - ) - ) - - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() - buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() - unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( - mapping.tp_rank + required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( + mapping.tp_size, + max(seq_lens), + hidden_size, + dtype, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.TWOSHOT, ) + workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) # Test each sequence length with the same workspace (reusing allocated buffers within this list) for seq_len in seq_lens: if rank == 0: - print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" - ) + print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") + print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") # Generate test data (same on all ranks due to same seed) x_full = torch.randn( @@ -242,12 +195,8 @@ def run_mnnvl_ar_full( dtype=dtype, device=torch.device("cuda"), ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") - ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") - ) + residual = torch.randn((seq_len, hidden_size), dtype=dtype, device=torch.device("cuda")) + norm_weight = torch.randn((hidden_size,), dtype=dtype, device=torch.device("cuda")) # Each rank gets its slice of the input x = x_full[rank, :, :] @@ -258,11 +207,7 @@ def run_mnnvl_ar_full( # Fused case: AllReduce + Residual Add + RMS Norm allreduce_result = torch.sum(x_full, dim=0) # AllReduce result residual_out = allreduce_result + residual # Add residual - print( - "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device - ) - ) + print("Device of residual_out:{}, norm_weight:{}".format(residual_out.device, norm_weight.device)) norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) reference_output = (norm_out, residual_out) @@ -282,24 +227,21 @@ def run_mnnvl_ar_full( mapping, fusion, reference_output, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - buffer_flags_mnnvl, + workspace, ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print( - f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" - ) + print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}") except Exception as e: rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) + import traceback + + print(traceback.format_exc()) # Gather failure status from all ranks for logging all_failures = MPI.COMM_WORLD.allgather(rank_failed) @@ -310,16 +252,16 @@ def run_mnnvl_ar_full( print(f"Test failed on ranks: {failed_ranks}") # Cleanup before re-raising - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: # Ensure cleanup happens for this list's workspace - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() @@ -348,61 +290,64 @@ def test_mnnvl_allreduce_default_workspace( run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) -"""Test with explicit workspace size""" - - -@pytest.mark.parametrize( - "seq_lens", - [ - [1, 4, 180], - ], -) -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_explicit_workspace( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test MNNVL AllReduce with explicitly calculated workspace size.""" - # Calculate workspace to fit the maximum sequence length - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) - run_mnnvl_ar_full( - monkeypatch, - seq_lens, - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=explicit_workspace_bytes, - ) - - -"""Negative test: workspace too small""" - - -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096]) -def test_mnnvl_allreduce_workspace_too_small( - monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" - # Use a large sequence length that won't fit in a small workspace - seq_len = 180 - - # Create a workspace that's too small (only enough for 10 tokens) - small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - - # Expect a ValueError with a message about buffer_M being too small - with pytest.raises((ValueError, RuntimeError)) as exc_info: - run_mnnvl_ar_full( - monkeypatch, - [seq_len], - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=small_workspace_bytes, - ) - - # Verify the error message contains the expected text - assert "greater than the buffer_M" in str(exc_info.value) +if __name__ == "__main__": + run_mnnvl_ar_full(None, [15], False, torch.bfloat16, 4096) + +# """Test with explicit workspace size""" + + +# @pytest.mark.parametrize( +# "seq_lens", +# [ +# [1, 4, 180], +# ], +# ) +# @pytest.mark.parametrize("fusion", [False, True]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +# def test_mnnvl_allreduce_explicit_workspace( +# monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +# ): +# """Test MNNVL AllReduce with explicitly calculated workspace size.""" +# # Calculate workspace to fit the maximum sequence length +# # buffer shape: [3, 2, buffer_tokens, hidden_dim] +# explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) +# run_mnnvl_ar_full( +# monkeypatch, +# seq_lens, +# fusion, +# dtype, +# hidden_size, +# explicit_workspace_bytes=explicit_workspace_bytes, +# ) + + +# """Negative test: workspace too small""" + + +# @pytest.mark.parametrize("fusion", [False, True]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("hidden_size", [2048, 4096]) +# def test_mnnvl_allreduce_workspace_too_small( +# monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int +# ): +# """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" +# # Use a large sequence length that won't fit in a small workspace +# seq_len = 180 + +# # Create a workspace that's too small (only enough for 10 tokens) +# small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 + +# # Expect a ValueError with a message about buffer_M being too small +# with pytest.raises((ValueError, RuntimeError)) as exc_info: +# run_mnnvl_ar_full( +# monkeypatch, +# [seq_len], +# fusion, +# dtype, +# hidden_size, +# explicit_workspace_bytes=small_workspace_bytes, +# ) + +# # Verify the error message contains the expected text +# assert "greater than the buffer_M" in str(exc_info.value) From 4caf71aa32bec398b5b6bfcb05f2abdd93d7cfc3 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 19:32:26 -0800 Subject: [PATCH 05/14] Passing the test. --- flashinfer/comm/trtllm_mnnvl_ar.py | 2 + flashinfer/jit/comm.py | 1 - tests/comm/test_trtllm_mnnvl_allreduce.py | 179 ++++++++++------------ 3 files changed, 81 insertions(+), 101 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 0b5db72628..eae919e5e0 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -275,6 +275,8 @@ def trtllm_mnnvl_allreduce( None, ) + return output + def trtllm_mnnvl_fused_allreduce_rmsnorm( input: torch.Tensor, diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 4f59c8930e..27661b1fe2 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -36,7 +36,6 @@ def gen_trtllm_mnnvl_comm_module() -> JitSpec: [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", ], - extra_cuda_cflags=["-lineinfo"], ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e0758c271c..6b89661650 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -5,10 +5,6 @@ import torch from mpi4py import MPI # Added MPI import -from flashinfer.utils import set_log_level - -set_log_level("debug") - import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import CommBackend, MpiComm @@ -23,24 +19,21 @@ def row_linear_residual_norm_fusion_forward( residual: torch.Tensor, norm_weight: torch.Tensor, eps: float, - hidden_size: int, - dtype: torch.dtype, mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): - x = x.cuda() - residual = residual.cuda() - norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_rank = mapping.tp_rank +<<<<<<< HEAD if comm_backend_for_handle_transfer is None: comm = MpiComm() else: comm = comm_backend_for_handle_transfer comm.barrier() +======= + MPI.COMM_WORLD.barrier() +>>>>>>> bca4f5d9 (Passing the test.) def func( input, @@ -65,7 +58,7 @@ def func( workspace, eps, launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) return output.view(shape), residual_out.view(shape) @@ -77,7 +70,7 @@ def func( input, workspace, launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) return (output.view(shape),) @@ -118,13 +111,49 @@ def func( """Helper function to run the core MNNVL AllReduce test logic""" +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + # Communicator used for passing data between ranks + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] = None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + def run_mnnvl_ar_full( - monkeypatch, - seq_lens: list[int], - fusion: bool, - dtype: torch.dtype, - hidden_size: int, - explicit_workspace_bytes: int | None = None, + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): """Core test logic for MNNVL AllReduce operations. @@ -136,18 +165,15 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - if monkeypatch is not None: - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + comm = MPI.COMM_WORLD # Get MPI info - rank = MPI.COMM_WORLD.Get_rank() - world_size = MPI.COMM_WORLD.Get_size() + rank = comm.Get_rank() + world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") - - # Ensure we have exactly 2 ranks for this test if world_size < 2: pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") @@ -162,10 +188,19 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: +<<<<<<< HEAD print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") tensor_parallel_size = world_size +======= + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) +>>>>>>> bca4f5d9 (Passing the test.) eps = 1e-5 torch.manual_seed(42 + rank) @@ -179,13 +214,23 @@ def run_mnnvl_ar_full( max(seq_lens), hidden_size, dtype, - trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.TWOSHOT, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) - # Test each sequence length with the same workspace (reusing allocated buffers within this list) + test_data = [] for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion + ) + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) + ) + + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + for seq_len, x, residual, norm_weight, reference_output in test_data: if rank == 0: +<<<<<<< HEAD print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") @@ -215,6 +260,11 @@ def run_mnnvl_ar_full( # Non-fused case: Only AllReduce allreduce_result = torch.sum(x_full, dim=0) # AllReduce result reference_output = (allreduce_result,) +======= + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) +>>>>>>> bca4f5d9 (Passing the test.) # Run the test with the same workspace row_linear_residual_norm_fusion_forward( @@ -222,8 +272,6 @@ def run_mnnvl_ar_full( residual, norm_weight, eps, - hidden_size, - dtype, mapping, fusion, reference_output, @@ -272,82 +320,13 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize( "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], ) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +@pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) def test_mnnvl_allreduce_default_workspace( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): """Test MNNVL AllReduce with default workspace size.""" run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) - - -if __name__ == "__main__": - run_mnnvl_ar_full(None, [15], False, torch.bfloat16, 4096) - -# """Test with explicit workspace size""" - - -# @pytest.mark.parametrize( -# "seq_lens", -# [ -# [1, 4, 180], -# ], -# ) -# @pytest.mark.parametrize("fusion", [False, True]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -# def test_mnnvl_allreduce_explicit_workspace( -# monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int -# ): -# """Test MNNVL AllReduce with explicitly calculated workspace size.""" -# # Calculate workspace to fit the maximum sequence length -# # buffer shape: [3, 2, buffer_tokens, hidden_dim] -# explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) -# run_mnnvl_ar_full( -# monkeypatch, -# seq_lens, -# fusion, -# dtype, -# hidden_size, -# explicit_workspace_bytes=explicit_workspace_bytes, -# ) - - -# """Negative test: workspace too small""" - - -# @pytest.mark.parametrize("fusion", [False, True]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("hidden_size", [2048, 4096]) -# def test_mnnvl_allreduce_workspace_too_small( -# monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -# ): -# """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" -# # Use a large sequence length that won't fit in a small workspace -# seq_len = 180 - -# # Create a workspace that's too small (only enough for 10 tokens) -# small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - -# # Expect a ValueError with a message about buffer_M being too small -# with pytest.raises((ValueError, RuntimeError)) as exc_info: -# run_mnnvl_ar_full( -# monkeypatch, -# [seq_len], -# fusion, -# dtype, -# hidden_size, -# explicit_workspace_bytes=small_workspace_bytes, -# ) - -# # Verify the error message contains the expected text -# assert "greater than the buffer_M" in str(exc_info.value) From 9a6beec1a41a8b24092d01321bbabd3fca433f35 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 19:54:54 -0800 Subject: [PATCH 06/14] Remove debug prints and add compatability interface. --- flashinfer/comm/mnnvl.py | 9 - flashinfer/comm/trtllm_mnnvl_ar.py | 237 +++++++++++++++++++++- tests/comm/test_trtllm_mnnvl_allreduce.py | 18 +- 3 files changed, 238 insertions(+), 26 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 520f6e4880..787f243995 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -131,9 +131,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: if not host_ptr_array: return None - for addr in host_ptr_array: - print(f"DEBUG: ptr_array: 0x{addr:x}") - ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) @@ -719,9 +716,6 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if not self.is_multi_node: - return - # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. if sys.is_finalizing(): @@ -884,7 +878,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) else: # Implement the allgather logic with ipc socket - # TODO: Do we need to model ipc socket as a comm backend? My tenative answer is no as it is not able to perform bootstrap without other communicator's help. all_shareable_uc_handles = [None] * self.group_size for i in range(self.group_size): self.comm_backend.barrier() @@ -896,8 +889,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() cuda.cuCtxSynchronize() - print(f"[Rank {self.group_rank}] all_shareable_uc_handles: {all_shareable_uc_handles}") - # Import remote handles for p in range(self.group_size): if p != self.group_rank: diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index eae919e5e0..a9c7a026e4 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -11,6 +11,7 @@ from enum import Enum import torch +from typing_extensions import deprecated from flashinfer.comm.mapping import Mapping @@ -278,7 +279,7 @@ def trtllm_mnnvl_allreduce( return output -def trtllm_mnnvl_fused_allreduce_rmsnorm( +def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: torch.Tensor, residual_in: torch.Tensor, gamma: torch.Tensor, @@ -289,10 +290,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Performs MNNVL Allreduce + RMSNorm. + """Performs MNNVL Allreduce + Residual + RMSNorm. This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. - After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + After this, it performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer. Note: multicast buffer is the same as the unicast buffer for the current rank. Args: @@ -307,8 +308,8 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. Returns: - output: Normalized tensor [num_tokens, hidden_dim] - residual_out: Residual output tensor [num_tokens, hidden_dim] + output: Add-residual and normalized tensor [num_tokens, hidden_dim] + residual_out: Add-residual tensor [num_tokens, hidden_dim] """ if epsilon is None: @@ -347,10 +348,6 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( ) ) - print( - f"[Rank {workspace.rank}] workspace.mc_ptr: {workspace.mc_ptr}, workspace.uc_ptrs_dev: {workspace.uc_ptrs_dev}, workspace.uc_ptr_local: {workspace.uc_ptr_local}" - ) - module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -369,3 +366,225 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( epsilon, ) return output, residual_out + + +# Legacy API that has been deprecated; Left for backward compatibility +@deprecated( + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" +) +def get_allreduce_mnnvl_workspace( + mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None +) -> Tuple[McastGPUBuffer, torch.Tensor, int]: + """Get workspace buffers needed for multi-node NVLink all-reduce operation. + + This function allocates and initializes the workspace buffers required for performing + multi-node NVLink all-reduce operations. It creates: + 1. A multicast GPU buffer for communication between nodes + 2. A flags tensor to track buffer state + 3. Maximum number of elements that can fit in the buffer + + The buffer size is calculated to efficiently handle common hidden dimensions + (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. + + Args: + mapping: Tensor parallel mapping configuration containing rank info + dtype: Data type of the tensors being reduced + buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens + + Returns: + Tuple containing: + - McastGPUBuffer: Multicast buffer for inter-node communication + - torch.Tensor: Buffer flags tensor tracking state + - int: Maximum number of elements that can fit in buffer + """ + # buffer shape: [3, 2, buffer_tokens, hidden_dim] + stride = 3 * 2 * dtype.itemsize + # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 + # max_num_elements must be a multiple of 286720 + lcm_hidden_dim = 286720 + TARGET_WORKSPACE_SIZE_BYTES = ( + buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + ) + buffer_size_in_bytes = math.ceil( + TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) + ) * (lcm_hidden_dim * stride) + + # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. + workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes) + + mcast_buffer = workspace.mcast_buffer_handle + buffer_flags = workspace.buffer_flags + max_num_elements = workspace.buffer_size_bytes // stride + + return ( + mcast_buffer, + buffer_flags, + max_num_elements, + ) + + +@deprecated( + "trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future." +) +def trtllm_mnnvl_all_reduce( + inp: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + buffer_M: int, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + wait_for_results: bool, + launch_with_pdl: bool, + out: Optional[torch.Tensor] = None, +) -> None: + """Perform a multi-node NVLink all-reduce operation across multiple GPUs. + + This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) + technology to efficiently combine tensors across multiple GPUs and nodes. + + There are 3 steps: + 1. scatter each GPU's input shard to the right unicast buffer + 2. perform all-reduce on each GPU + 3. broadcast the result to all GPUs + + Args: + inp: Local Input Shard + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to device buffer pointers as an integer + buffer_M: Maximum number of elements // hidden_dim + buffer_flags_mnnvl: Tensor containing buffer state flags + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + wait_for_results: If True, store the result to out + launch_with_pdl: If True, launch using Programmatic Dependent Launch + [Optional] out: Output tensor to store the result (required if wait_for_results is True) + + """ + + if len(inp.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if inp.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) + module = get_trtllm_mnnvl_comm_module() + module.trtllm_mnnvl_allreduce_fusion( + input, + multicast_buffer_ptr, + buffer_ptrs_dev, + 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. + buffer_flags_mnnvl, + nranks, + rank, + False, # No RMSNorm Fusion + launch_with_pdl, + False, # Use two-shot + out, + None, + None, + None, + None, + ) + + +@deprecated( + "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future." +) +def trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output: torch.Tensor, + normed_output: torch.Tensor, + shard_input: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + unicast_ptr: int, # Local unicast buffer pointer + buffer_M: int, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + gamma: torch.Tensor, + epsilon: float, + residual: torch.Tensor, + launch_with_pdl: bool, +) -> None: + """Performs MNNVL TwoShot Allreduce + RMSNorm. + + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. + After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + Note: multicast buffer is the same as the unicast buffer for the current rank. + + Args: + prenorm_output: Output tensor for prenorm results + normed_output: Output tensor for normalized results + shard_input: Input tensor shard + multicast_buffer_ptr: Pointer address as integer for multicast buffer + buffer_ptrs_dev: Pointer address as integer for device buffer pointers + unicast_ptr: Pointer address as integer for unicast buffer + buffer_M: Maximum number of elements // hidden_dim + buffer_flags_mnnvl: Buffer flags for synchronization + nranks: Number of ranks in the tensor parallel group + rank: Current rank in the tensor parallel group + gamma: The gamma (norm weight) parameter for RMSNorm + epsilon: The epsilon parameter for RMSNorm + residual: The residual tensor to add + launch_with_pdl: Whether to launch with PDL + + """ + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if shard_input.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + if len(residual.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}." + ) + if gamma.numel() != shard_input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements." + ) + + if len(normed_output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}." + ) + + if len(prenorm_output.shape) != 2: + raise ValueError( + f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + module.trtllm_mnnvl_allreduce_fusion( + shard_input, + multicast_buffer_ptr, + buffer_ptrs_dev, + unicast_ptr, + buffer_flags_mnnvl, + nranks, + rank, + True, # RMSNorm Fusion + launch_with_pdl, + False, + normed_output, + prenorm_output, + residual, + gamma, + epsilon, + ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 6b89661650..b77c5a91d1 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -51,14 +51,16 @@ def func( if enable_fusion: trtllm_mnnvl_ar.mpi_barrier() - output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( - input, - residual, - norm_weight, - workspace, - eps, - launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + output, residual_out = ( + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input, + residual, + norm_weight, + workspace, + eps, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) ) return output.view(shape), residual_out.view(shape) From a4d1a1757e007c2abbda400c48b3a368289aa277 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 20:12:22 -0800 Subject: [PATCH 07/14] Incorporate 2056; Add test for legacy APIs --- flashinfer/comm/mnnvl.py | 4 + flashinfer/comm/trtllm_mnnvl_ar.py | 44 +++-- tests/comm/test_trtllm_mnnvl_allreduce.py | 227 +++++++++++++++++++--- 3 files changed, 230 insertions(+), 45 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 787f243995..6ca3a4b866 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -1033,7 +1033,11 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used +<<<<<<< HEAD comm_backend_for_handle_transfer: Communication backend for handle transfer +======= + comm_backend_for_handle_transfer: The communicator to use for handle transfer +>>>>>>> a2670e8c (Incorporate 2056; Add test for legacy APIs) """ self.mcast_device_memory = McastDeviceMemory( buf_size, diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index a9c7a026e4..82c40a7c83 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -17,7 +17,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer, CommBackend +from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend def mpi_barrier(): @@ -39,14 +39,18 @@ def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dty # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size -# TODO(Refactor): Consider moving this to a configuration class or file MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 class MNNVLAllreduceFusionWorkspace: NUM_LAMPORT_BUFFERS = 3 - def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None): + def __init__( + self, + mapping: Mapping, + buffer_size_in_bytes: Optional[int] = None, + comm_backend: Optional[CommBackend] = None, + ): """ Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. @@ -60,7 +64,8 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) - + if comm_backend is None: + comm_backend = MPIBackend() if buffer_size_in_bytes > (2**32 - 1): raise ValueError( f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." @@ -79,14 +84,14 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node(), + comm_backend, ) # We use FP32 for sentinel value regardless of the real dtype self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) # Wait until the initialization is done torch.cuda.synchronize() - # FIXME: We are assuming using the COMM_WORLD. - mpi_barrier() + comm_backend.barrier() # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer @@ -373,7 +378,10 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" ) def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None + mapping: Mapping, + dtype: torch.dtype, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, + buffer_size_in_bytes: Optional[int] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -402,15 +410,13 @@ def get_allreduce_mnnvl_workspace( # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 # max_num_elements must be a multiple of 286720 lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = ( - buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + TARGET_WORKSPACE_SIZE_BYTES = buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + buffer_size_in_bytes = math.ceil(TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)) * ( + lcm_hidden_dim * stride ) - buffer_size_in_bytes = math.ceil( - TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) - ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes) + workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer) mcast_buffer = workspace.mcast_buffer_handle buffer_flags = workspace.buffer_flags @@ -463,9 +469,7 @@ def trtllm_mnnvl_all_reduce( """ if len(inp.shape) != 2: - raise ValueError( - f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." - ) + raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: @@ -474,12 +478,12 @@ def trtllm_mnnvl_all_reduce( ) # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. - assert wait_for_results and (out is not None), ( - "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." - ) + assert wait_for_results and ( + out is not None + ), "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." module = get_trtllm_mnnvl_comm_module() module.trtllm_mnnvl_allreduce_fusion( - input, + inp, multicast_buffer_ptr, buffer_ptrs_dev, 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index b77c5a91d1..461e1527ac 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -110,6 +110,131 @@ def func( ) +@torch.inference_mode() +def row_linear_residual_norm_fusion_forward_legacy( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + hidden_size: int, + dtype: torch.dtype, + mapping: Mapping, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + multicast_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, + max_num_elements_mnnvl: int, + buffer_flags_mnnvl: torch.Tensor, +): + tensor_parallel_size = mapping.tp_size + tensor_parallel_rank = mapping.tp_rank + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + buffer_M = max_num_elements_mnnvl // hidden_size + + if enable_fusion: + use_pdl = True + + prenorm_output = torch.empty_like(residual) + normed_output = torch.empty_like(residual) + + trtllm_mnnvl_ar.mpi_barrier() + + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output, + normed_output, + input, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + norm_weight, + eps, + residual, + use_pdl, + ) + + return normed_output.view(shape), prenorm_output.view(shape) + + else: + output = torch.empty_like(input) + + trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( + input, + multicast_ptr, + buffer_ptrs_dev, + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + True, # wait_for_results + False, # launch_with_pdl + output, # Need to provide output tensor since we are writing them out. + ) + return (output.view(shape),) + + output = func( + x.clone(), + residual.clone(), + norm_weight, + eps, + fusion, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + """Helper function to run the core MNNVL AllReduce test logic""" @@ -155,7 +280,13 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion def run_mnnvl_ar_full( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + legacy_explicit_workspace_bytes: int = None, + legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. @@ -211,14 +342,30 @@ def run_mnnvl_ar_full( failure_message = "" try: - required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( - mapping.tp_size, - max(seq_lens), - hidden_size, - dtype, - trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, - ) - workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) + if legacy_api: + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes + ) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) + + else: + required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( + mapping.tp_size, + max(seq_lens), + hidden_size, + dtype, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( + mapping, required_workspace_bytes + ) test_data = [] for seq_len in seq_lens: @@ -266,19 +413,34 @@ def run_mnnvl_ar_full( print( f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) ->>>>>>> bca4f5d9 (Passing the test.) - - # Run the test with the same workspace - row_linear_residual_norm_fusion_forward( - x, - residual, - norm_weight, - eps, - mapping, - fusion, - reference_output, - workspace, - ) + if legacy_api: + row_linear_residual_norm_fusion_forward_legacy( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + fusion, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, + ) + else: + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + mapping, + fusion, + reference_output, + workspace, + ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() @@ -327,8 +489,23 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) -def test_mnnvl_allreduce_default_workspace( +def test_mnnvl_allreduce_refactored( + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +): + """Test MNNVL AllReduce with refactored API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=False + ) + + +@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]]) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +def test_mnnvl_allreduce_legacy( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with default workspace size.""" - run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) + """Test MNNVL AllReduce with legacy API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True + ) From 01564e97d007485bc971018a4ca89e6f8cb40093 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 20:31:37 -0800 Subject: [PATCH 08/14] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 3 +-- flashinfer/comm/trtllm_mnnvl_ar.py | 3 --- tests/comm/test_trtllm_mnnvl_allreduce.py | 5 ++--- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index c7215a4241..5049344872 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -103,8 +103,7 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt status = twoshotAllreduceFusionDispatch(params); } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); + << "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status); }); } diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 82c40a7c83..7770cb815e 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -192,9 +192,6 @@ def trtllm_mnnvl_allreduce_fusion( gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ - print( - f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}" - ) module.trtllm_mnnvl_allreduce_fusion( input, multicast_buffer_ptr, diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 461e1527ac..cb5425f5e2 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,5 +1,6 @@ # Check torch version: -from typing import Tuple, Optional +import traceback +from typing import Tuple import pytest import torch @@ -451,8 +452,6 @@ def run_mnnvl_ar_full( rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) - import traceback - print(traceback.format_exc()) # Gather failure status from all ranks for logging From 775918d5d8d6581815129ca72f7638f098ce1e72 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 20 Nov 2025 13:36:41 -0800 Subject: [PATCH 09/14] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 11 +++++++++-- flashinfer/comm/mnnvl.py | 18 ++++++++++++++++-- include/flashinfer/utils.cuh | 13 +++++++++---- tests/comm/test_trtllm_mnnvl_allreduce.py | 4 ++-- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 5049344872..dea2ddd039 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -53,11 +53,18 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) || + TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() && + epsilon.has_value()) || !rmsnorm_fusion) - << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; + << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is " + "true"; if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1) + << ")"; TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 6ca3a4b866..b6bbdd3906 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -716,6 +716,9 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return + if hasattr(self, "_ipc_socket"): + self._ipc_socket.close() + # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. if sys.is_finalizing(): @@ -864,7 +867,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Allocate local GPU memory self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) - # Export local handle to fabric handle + # Export local handle to fabric handle or FD local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], @@ -898,6 +901,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._shareable_handle_type, ) ) + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(all_shareable_uc_handles[p]) # Initialize multicasting if self.group_rank == 0: @@ -943,7 +952,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._shareable_handle_type, ) ) - + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(shareable_mc_handle) # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 20c19a0eae..8481aabf39 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -335,16 +336,20 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +// This function is thread-safe and cached the sm_count. +// But it will only check the current CUDA device, thus assuming each process handles single GPU. inline int GetCudaMultiProcessorCount() { - static int sm_count = 0; - if (sm_count == 0) { + static std::atomic sm_count{0}; + int cached = sm_count.load(std::memory_order_relaxed); + if (cached == 0) { int device_id; cudaGetDevice(&device_id); cudaDeviceProp device_prop; cudaGetDeviceProperties(&device_prop, device_id); - sm_count = device_prop.multiProcessorCount; + cached = device_prop.multiProcessorCount; + sm_count.store(cached, std::memory_order_relaxed); } - return sm_count; + return cached; } template diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index cb5425f5e2..43437faf4b 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,6 +1,6 @@ # Check torch version: import traceback -from typing import Tuple +from typing import Tuple, Optional import pytest import torch @@ -286,7 +286,7 @@ def run_mnnvl_ar_full( fusion: bool, dtype: torch.dtype, hidden_size: int, - legacy_explicit_workspace_bytes: int = None, + legacy_explicit_workspace_bytes: Optional[int] = None, legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. From 45a5b828c9f3a94ad25b5fe350bb978c16296431 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 21 Nov 2025 14:37:32 -0800 Subject: [PATCH 10/14] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 2 - flashinfer/comm/trtllm_mnnvl_ar.py | 83 +++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index dea2ddd039..e1c998d8ea 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,8 +26,6 @@ using tvm::ffi::Optional; } \ }() -// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this -// level void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 7770cb815e..4a69b83d8a 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -33,9 +33,18 @@ class MNNVLAllreduceFusionStrategy(Enum): AUTO = 99 @staticmethod +<<<<<<< HEAD def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: +======= + def select_strategy( + tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype + ) -> "MNNVLAllreduceFusionStrategy": +>>>>>>> c6ed1472 (Address review comments.) elem_size = torch.tensor([], dtype=dtype).element_size() - return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD + if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: + return MNNVLAllreduceFusionStrategy.ONESHOT + else: + return MNNVLAllreduceFusionStrategy.TWOSHOT # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size @@ -52,15 +61,15 @@ def __init__( comm_backend: Optional[CommBackend] = None, ): """ - Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. + Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. Args: mapping: Mapping configuration containing rank info buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. """ if buffer_size_in_bytes is None: - # Default to 16MB workspace size if not provided - buffer_size_in_bytes = 16 * (1024**2) + # Default to 512MB workspace size if not provided + buffer_size_in_bytes = 512 * (1024**2) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) @@ -108,7 +117,28 @@ def __init__( self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + @functools.cache + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> bool: + """ + Calculate the required buffer size for a given problem size. + """ + required_buffer_size = self.get_required_buffer_size_bytes( + tp_size, num_tokens, hidden_dim, dtype, strategy + ) + if required_buffer_size > self.buffer_size_bytes: + return False + else: + return True + @staticmethod + @functools.cache def get_required_buffer_size_bytes( tp_size: int, num_tokens: int, @@ -120,10 +150,19 @@ def get_required_buffer_size_bytes( Calculate the required buffer size for a given problem size. """ elem_size = torch.tensor([], dtype=dtype).element_size() +<<<<<<< HEAD is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot ): +======= + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + tp_size, num_tokens, hidden_dim, dtype + ) + + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: +>>>>>>> c6ed1472 (Address review comments.) # For one-shot, each rank needs to store num_tokens * tp_size tokens buffer_size = num_tokens * hidden_dim * tp_size * elem_size else: @@ -256,10 +295,25 @@ def trtllm_mnnvl_allreduce( module = get_trtllm_mnnvl_comm_module() +<<<<<<< HEAD use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) ) +======= + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + +>>>>>>> c6ed1472 (Address review comments.) module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -270,7 +324,7 @@ def trtllm_mnnvl_allreduce( workspace.rank, False, # No RMSNorm Fusion launch_with_pdl, - use_oneshot, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, output, None, None, @@ -340,15 +394,16 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( module = get_trtllm_mnnvl_comm_module() - use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot( - workspace.tp_size, - input.shape[0], - input.shape[1], - input.dtype, + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." ) - ) module.trtllm_mnnvl_allreduce_fusion( input, @@ -360,7 +415,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( workspace.rank, True, # RMSNorm Fusion launch_with_pdl, - use_oneshot, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, output, residual_out, residual_in, From 815aaf33dc6f3479db57e1c066a679e827fd2ca2 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 21 Nov 2025 15:20:35 -0800 Subject: [PATCH 11/14] Rounding up workspace size according to allocation (page size). --- flashinfer/comm/mnnvl.py | 19 ++++++++++++++----- flashinfer/comm/trtllm_mnnvl_ar.py | 30 +++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index b6bbdd3906..66c47e4c9c 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -803,6 +803,14 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size + def get_allocation_size(self) -> int: + """Get the total allocation size (including signal pad)""" + return self.allocation_size + + def get_usable_buffer_size(self) -> int: + """Get the usable buffer size (excluding signal pad)""" + return self.allocation_size - self.SIGNAL_PAD_SIZE + def _init_ipc_socket(self): if self.group_rank == 0: # Gnerate the opId @@ -838,7 +846,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, - cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, + cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED, ) ) @@ -1015,8 +1023,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): else: raise ValueError(f"Unsupported dtype: {dtype}") - # Calculate number of elements that fit in allocation_size - num_elements = self.allocation_size // dsize + # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. + num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) @@ -1042,7 +1050,7 @@ def __init__( Constructor for McastGpuBuffer. Args: - buf_size: The total size of the buffer in bytes + buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements. group_size: The number of ranks in the communication group group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation @@ -1061,7 +1069,8 @@ def __init__( mn_nvlink, comm_backend_for_handle_transfer, ) - self.buf_size = buf_size + # Update buf_size to reflect the actual usable buffer size after allocation + self.buf_size = self.mcast_device_memory.get_usable_buffer_size() self.local_device = device def lamport_initialize(self, rank: int, dtype: torch.dtype): diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 4a69b83d8a..4244e00aa4 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -65,11 +65,11 @@ def __init__( Args: mapping: Mapping configuration containing rank info - buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. + buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. """ if buffer_size_in_bytes is None: - # Default to 512MB workspace size if not provided - buffer_size_in_bytes = 512 * (1024**2) + # Default to 16MB workspace size if not provided + buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) @@ -80,15 +80,18 @@ def __init__( f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." ) - self.buffer_size_bytes = buffer_size_in_bytes - self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + # Calculate total requested workspace size + requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + self.rank = mapping.tp_rank self.tp_size = mapping.tp_size logging.debug( - f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes." + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer." ) + + # Allocate the workspace self.mcast_buffer_handle = McastGPUBuffer( - self.workspace_size_bytes, + requested_workspace_size, mapping.tp_size, mapping.tp_rank, torch.device("cuda", mapping.local_rank), @@ -96,6 +99,19 @@ def __init__( comm_backend, ) + # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer) + allocated_size = self.mcast_buffer_handle.buf_size + # We want the buffer size to be aligned to 16B which is the granularity for buffer management. + self.buffer_size_bytes = ( + math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 + ) + # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size. + self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS + + logging.debug( + f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes." + ) + # We use FP32 for sentinel value regardless of the real dtype self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) # Wait until the initialization is done From 68a9b9b8ef33705d293d1e5e3030e0cc6d1aa45c Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 15:56:04 -0800 Subject: [PATCH 12/14] Fix rebasing errors. --- flashinfer/comm/mnnvl.py | 233 +++++++++++++++++++++-------- flashinfer/comm/trtllm_mnnvl_ar.py | 64 ++++---- 2 files changed, 204 insertions(+), 93 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 66c47e4c9c..3128a9874a 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -41,7 +41,8 @@ from cuda import cuda except ImportError as e: raise ImportError( - "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." + "Could not import the 'cuda' module. " + "Please install cuda-python that matches your CUDA version." ) from e from ..cuda_utils import checkCudaErrors @@ -62,7 +63,9 @@ def round_up(val: int, gran: int) -> int: return (val + gran - 1) & ~(gran - 1) -def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, device_id: int) -> torch.Tensor: +def create_tensor_from_cuda_memory( + ptr: int, shape: tuple, dtype: torch.dtype, device_id: int +) -> torch.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. @@ -84,7 +87,9 @@ def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, d element_size = torch.tensor([], dtype=dtype).element_size() # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) - capsule_wrapper = create_dlpack_capsule(ptr, element_size, element_size, numel, dtype, device_id) + capsule_wrapper = create_dlpack_capsule( + ptr, element_size, element_size, numel, dtype, device_id + ) # Convert to tensor and reshape tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) @@ -136,7 +141,9 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: size_in_bytes = ctypes.sizeof(c_array) device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) - checkCudaErrors(cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)) + checkCudaErrors( + cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) + ) # c_array should be freed by GC return int(device_ptr) @@ -292,14 +299,18 @@ def initialize(): @staticmethod def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None): MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] - comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) + comm = config.comm_backend.Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank + ) MnnvlMemory.comm = comm # type: ignore[assignment] @staticmethod def get_comm(mapping: Mapping): if MnnvlMemory.comm is not None: return MnnvlMemory.comm - comm = MpiComm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) + comm = MpiComm().Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank + ) MnnvlMemory.comm = comm return comm @@ -315,7 +326,9 @@ def get_allocation_prop(dev_id: int): arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch if is_on_aarch64: - allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + allocation_prop.requestedHandleTypes = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) else: allocation_prop.requestedHandleTypes = ( cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR @@ -331,19 +344,27 @@ def get_allocation_granularity(dev_id: int): option = cuda.CUmemAllocationGranularity_flags( cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ) - granularity = checkCudaErrors(cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)) + granularity = checkCudaErrors( + cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option) + ) MnnvlMemory.allocation_granularity = granularity return MnnvlMemory.allocation_granularity @staticmethod def new_mnnvl_memory_address(mapping: Mapping, size: int): - page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size + page_count = ( + size + MnnvlMemory.fabric_page_size - 1 + ) // MnnvlMemory.fabric_page_size current_rank_stride = page_count * MnnvlMemory.fabric_page_size - logging.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}") + logging.info( + f"[MnnvlMemory] creating address with stride={current_rank_stride}" + ) comm = MnnvlMemory.get_comm(mapping) comm_size = comm.Get_size() address_size = current_rank_stride * comm_size - ptr = checkCudaErrors(cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)) + ptr = checkCudaErrors( + cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0) + ) MnnvlMemory.current_start_address = int(ptr) MnnvlMemory.current_rank_stride = current_rank_stride MnnvlMemory.current_mem_offset = 0 @@ -354,29 +375,44 @@ def open_mnnvl_memory(mapping: Mapping, size: int): dev_id = int(dev) if MnnvlMemory.dev_id is None: MnnvlMemory.dev_id = dev_id - assert ( - dev_id == MnnvlMemory.dev_id - ), f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" + assert dev_id == MnnvlMemory.dev_id, ( + f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" + ) comm = MnnvlMemory.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) assert len(all_rank_allocate_sizes) == comm_size - assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size." + assert all(x == size for x in all_rank_allocate_sizes), ( + "Not all rank allocating same size." + ) granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride: + if ( + MnnvlMemory.current_mem_offset + aligned_size + > MnnvlMemory.current_rank_stride + ): MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) - assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride + assert ( + MnnvlMemory.current_mem_offset + aligned_size + <= MnnvlMemory.current_rank_stride + ) allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) - allocated_mem_handle = checkCudaErrors(cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)) + allocated_mem_handle = checkCudaErrors( + cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) + ) exported_fabric_handle = checkCudaErrors( - cuda.cuMemExportToShareableHandle(allocated_mem_handle, allocation_prop.requestedHandleTypes, 0) + cuda.cuMemExportToShareableHandle( + allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 + ) ) - if allocation_prop.requestedHandleTypes == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + allocation_prop.requestedHandleTypes + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): all_handles_data = comm.allgather(exported_fabric_handle.data) else: all_handles_data = comm.allgather(exported_fabric_handle) @@ -390,7 +426,9 @@ def open_mnnvl_memory(mapping: Mapping, size: int): pidfd = syscall(SYS_pidfd_open, pid, 0) if pidfd < 0: err = ctypes.get_errno() - raise RuntimeError(f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}") + raise RuntimeError( + f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" + ) pidfds.append(pidfd) remote_fds = [] @@ -405,7 +443,9 @@ def open_mnnvl_memory(mapping: Mapping, size: int): "to your docker run command." ) else: - error_msg += " This may be due to kernel version (requires Linux 5.6+)." + error_msg += ( + " This may be due to kernel version (requires Linux 5.6+)." + ) raise RuntimeError(error_msg) remote_fds.append(remote_fd) @@ -421,19 +461,27 @@ def open_mnnvl_memory(mapping: Mapping, size: int): for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( - MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset + MnnvlMemory.current_start_address + + MnnvlMemory.current_rank_stride * i + + MnnvlMemory.current_mem_offset ) if i == comm_rank: # Local memory mapping mem_handles[i] = allocated_mem_handle - checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0)) + checkCudaErrors( + cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0) + ) else: # Fabric memory mapping imported_mem_handle = checkCudaErrors( - cuda.cuMemImportFromShareableHandle(remote_handle_data, allocation_prop.requestedHandleTypes) + cuda.cuMemImportFromShareableHandle( + remote_handle_data, allocation_prop.requestedHandleTypes + ) ) mem_handles[i] = imported_mem_handle - checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0)) + checkCudaErrors( + cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) + ) checkCudaErrors(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)) @@ -490,14 +538,20 @@ def support_nvlink(need_all_up: bool = True): available_links = 0 for link_idx in range(link_count): try: - if pynvml.nvmlDeviceGetNvLinkCapability(handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED): + if pynvml.nvmlDeviceGetNvLinkCapability( + handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED + ): available_links += 1 is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx) if is_active: active_links += 1 except pynvml.NVMLError_NotSupported: continue - return active_links == available_links and available_links > 0 if need_all_up else available_links > 0 + return ( + active_links == available_links and available_links > 0 + if need_all_up + else available_links > 0 + ) @staticmethod def supports_mnnvl() -> bool: @@ -585,14 +639,18 @@ def recv_fd(self): fds = array.array("i") msg, ancdata, flags, addr = self.sock.recvmsg( 1, - socket.CMSG_SPACE(fds.itemsize), # Buffer size for dummy data # Ancillary data size + socket.CMSG_SPACE( + fds.itemsize + ), # Buffer size for dummy data # Ancillary data size ) # Extract file descriptor from ancillary data for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: fds = array.array("i") - fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + fds.frombytes( + cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)] + ) return fds[0] raise RuntimeError("No file descriptor received") @@ -617,7 +675,6 @@ def __init__( device_idx: int, is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -653,7 +710,9 @@ def __init__( self.signal_pads_dev = 0 # std::vector mSignalPadsDev self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle - self.uc_handles: List[int] = [] # std::vector mUcHandles + self.uc_handles: List[ + int + ] = [] # std::vector mUcHandles self._shareable_handle_type = None @@ -669,7 +728,9 @@ def __init__( ) ) if multicast_supported == 0: - raise RuntimeError("[McastDeviceMemory] Device does not support multicasting.") + raise RuntimeError( + "[McastDeviceMemory] Device does not support multicasting." + ) # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) @@ -689,13 +750,19 @@ def __init__( ) ) if fabric_handle_supported == 0: - raise RuntimeError("[McastDeviceMemory] Device does not support fabric handle.") + raise RuntimeError( + "[McastDeviceMemory] Device does not support fabric handle." + ) # Use fabric handle for multi-node NVLS - self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) else: self._init_ipc_socket() # Use NVLink handle for single-node NVLS - self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ) self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads @@ -703,7 +770,9 @@ def __init__( for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: - checkCudaErrors(cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)) + checkCudaErrors( + cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) + ) # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) @@ -746,19 +815,29 @@ def __del__(self): checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: - checkCudaErrors(cuda.cuMemUnmap(self.uc_ptrs[rank], self.allocation_size)) + checkCudaErrors( + cuda.cuMemUnmap( + self.uc_ptrs[rank], self.allocation_size + ) + ) except Exception as e: - print(f"Destructor: Failed to release UC handle for rank {rank}: {e}") + print( + f"Destructor: Failed to release UC handle for rank {rank}: {e}" + ) # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: - checkCudaErrors(cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size)) + checkCudaErrors( + cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) + ) # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) - checkCudaErrors(cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size)) + checkCudaErrors( + cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) + ) checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") @@ -828,7 +907,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): current_device = checkCudaErrors(cuda.cuCtxGetDevice()) if int(current_device) != self.device_idx: - print(f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}") + print( + f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" + ) except Exception as e: print(f"Error checking CUDA context: {e}") @@ -837,7 +918,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): allocation_prop.requestedHandleTypes = self._shareable_handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() - allocation_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + allocation_prop.location.type = ( + cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + ) allocation_prop.location.id = self.device_idx allocation_prop.allocFlags.gpuDirectRDMACapable = 1 @@ -851,7 +934,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); - self.allocation_size = round_up(buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity) + self.allocation_size = round_up( + buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity + ) # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() @@ -873,7 +958,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self.uc_handles = [0] * self.group_size # Allocate local GPU memory - self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) + self.uc_handles[self.group_rank] = checkCudaErrors( + cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) + ) # Export local handle to fabric handle or FD local_shareable_uc_handle = checkCudaErrors( @@ -884,9 +971,14 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) - if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): # All-gather fabric handles - all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) + all_shareable_uc_handles = self.comm_backend.allgather( + local_shareable_uc_handle.data + ) else: # Implement the allgather logic with ipc socket all_shareable_uc_handles = [None] * self.group_size @@ -931,7 +1023,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) else: shareable_mc_handle = None - if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): # Broadcast multicast handle shareable_mc_handle = self.comm_backend.bcast( shareable_mc_handle.data if shareable_mc_handle else None, root=0 @@ -975,7 +1070,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors(cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0)) + uc_base_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) + ) self.uc_base_ptr = uc_base_ptr # Store for cleanup # Set up memory access descriptor @@ -989,15 +1086,27 @@ def _alloc_mn_mcast_mem(self, buf_size: int): for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors(cuda.cuMemMap(self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0)) + checkCudaErrors( + cuda.cuMemMap( + self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 + ) + ) # Set memory access permissions - checkCudaErrors(cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1)) + checkCudaErrors( + cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) + ) # Bind MC pointer - self.mc_ptr = checkCudaErrors(cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0)) - checkCudaErrors(cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0)) - checkCudaErrors(cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1)) + self.mc_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + ) + checkCudaErrors( + cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) + ) + checkCudaErrors( + cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) + ) # Bind memory to multicast checkCudaErrors( @@ -1026,7 +1135,9 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize - checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) + checkCudaErrors( + memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) + ) class McastGPUBuffer: @@ -1055,11 +1166,7 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used -<<<<<<< HEAD comm_backend_for_handle_transfer: Communication backend for handle transfer -======= - comm_backend_for_handle_transfer: The communicator to use for handle transfer ->>>>>>> a2670e8c (Incorporate 2056; Add test for legacy APIs) """ self.mcast_device_memory = McastDeviceMemory( buf_size, @@ -1076,7 +1183,9 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: + def get_multicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. @@ -1092,7 +1201,9 @@ def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. raise NotImplementedError("Not implemented yet") - def get_unicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: """ Returns a PyTorch tensor view of the unicast buffer portion. """ diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 4244e00aa4..afdd580910 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -33,13 +33,9 @@ class MNNVLAllreduceFusionStrategy(Enum): AUTO = 99 @staticmethod -<<<<<<< HEAD - def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: -======= def select_strategy( tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype ) -> "MNNVLAllreduceFusionStrategy": ->>>>>>> c6ed1472 (Address review comments.) elem_size = torch.tensor([], dtype=dtype).element_size() if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: return MNNVLAllreduceFusionStrategy.ONESHOT @@ -72,7 +68,9 @@ def __init__( buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB - buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * ( + 8 * (1024**2) + ) if comm_backend is None: comm_backend = MPIBackend() if buffer_size_in_bytes > (2**32 - 1): @@ -166,25 +164,20 @@ def get_required_buffer_size_bytes( Calculate the required buffer size for a given problem size. """ elem_size = torch.tensor([], dtype=dtype).element_size() -<<<<<<< HEAD - is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) - if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot - ): -======= if strategy == MNNVLAllreduceFusionStrategy.AUTO: strategy = MNNVLAllreduceFusionStrategy.select_strategy( tp_size, num_tokens, hidden_dim, dtype ) if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: ->>>>>>> c6ed1472 (Address review comments.) # For one-shot, each rank needs to store num_tokens * tp_size tokens buffer_size = num_tokens * hidden_dim * tp_size * elem_size else: # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. # 2 Stage is required for the two-shot allreduce. - buffer_size = 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + buffer_size = ( + 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + ) return buffer_size @@ -302,21 +295,19 @@ def trtllm_mnnvl_allreduce( # Check ndims here as the shape check is done in the kernel launch code. if len(input.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) if output is None: output = torch.empty_like(input) elif len(output.shape) != 2: - raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) module = get_trtllm_mnnvl_comm_module() -<<<<<<< HEAD - use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) - ) -======= if strategy == MNNVLAllreduceFusionStrategy.AUTO: strategy = MNNVLAllreduceFusionStrategy.select_strategy( workspace.tp_size, input.shape[0], input.shape[1], input.dtype @@ -329,7 +320,6 @@ def trtllm_mnnvl_allreduce( f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." ) ->>>>>>> c6ed1472 (Address review comments.) module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -388,7 +378,9 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( epsilon = torch.finfo(input.dtype).eps if len(input.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) if len(residual_in.shape) != 2: raise ValueError( f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." @@ -400,7 +392,9 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( if output is None: output = torch.empty_like(input) elif len(output.shape) != 2: - raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) if residual_out is None: residual_out = torch.empty_like(residual_in) elif len(residual_out.shape) != 2: @@ -478,13 +472,17 @@ def get_allreduce_mnnvl_workspace( # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 # max_num_elements must be a multiple of 286720 lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 - buffer_size_in_bytes = math.ceil(TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)) * ( - lcm_hidden_dim * stride + TARGET_WORKSPACE_SIZE_BYTES = ( + buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 ) + buffer_size_in_bytes = math.ceil( + TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) + ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer) + workspace = MNNVLAllreduceFusionWorkspace( + mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer + ) mcast_buffer = workspace.mcast_buffer_handle buffer_flags = workspace.buffer_flags @@ -537,7 +535,9 @@ def trtllm_mnnvl_all_reduce( """ if len(inp.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." + ) # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: @@ -546,9 +546,9 @@ def trtllm_mnnvl_all_reduce( ) # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. - assert wait_for_results and ( - out is not None - ), "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) module = get_trtllm_mnnvl_comm_module() module.trtllm_mnnvl_allreduce_fusion( inp, From 9e11752cbe5872a0be82928e30e735ed2eae85fd Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 16:20:14 -0800 Subject: [PATCH 13/14] Fix rebase errors. --- tests/comm/test_trtllm_mnnvl_allreduce.py | 51 ++--------------------- 1 file changed, 3 insertions(+), 48 deletions(-) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 43437faf4b..cf93b1af6c 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -8,7 +8,6 @@ import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping -from flashinfer.comm.mnnvl import CommBackend, MpiComm # Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm @@ -26,15 +25,7 @@ def row_linear_residual_norm_fusion_forward( workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank -<<<<<<< HEAD - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - comm.barrier() -======= MPI.COMM_WORLD.barrier() ->>>>>>> bca4f5d9 (Passing the test.) def func( input, @@ -322,19 +313,12 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: -<<<<<<< HEAD - print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") - print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") - - tensor_parallel_size = world_size -======= print( f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" ) print( f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" ) ->>>>>>> bca4f5d9 (Passing the test.) eps = 1e-5 torch.manual_seed(42 + rank) @@ -380,37 +364,6 @@ def run_mnnvl_ar_full( # Test each sequence length with the same workspace (reusing allocated buffers within this list) for seq_len, x, residual, norm_weight, reference_output in test_data: if rank == 0: -<<<<<<< HEAD - print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") - print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") - - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), - dtype=dtype, - device=torch.device("cuda"), - ) - residual = torch.randn((seq_len, hidden_size), dtype=dtype, device=torch.device("cuda")) - norm_weight = torch.randn((hidden_size,), dtype=dtype, device=torch.device("cuda")) - - # Each rank gets its slice of the input - x = x_full[rank, :, :] - - # Compute reference output based on fusion mode - reference_output: Tuple[torch.Tensor, ...] = None - if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual - print("Device of residual_out:{}, norm_weight:{}".format(residual_out.device, norm_weight.device)) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - - reference_output = (norm_out, residual_out) - else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - reference_output = (allreduce_result,) -======= print( f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) @@ -446,7 +399,9 @@ def run_mnnvl_ar_full( # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}") + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" + ) except Exception as e: rank_failed = True From 4a5faeff6509c97d454bfce7a9369565f93dcd79 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 16:24:27 -0800 Subject: [PATCH 14/14] Refactor mcast device memory. --- flashinfer/comm/mnnvl.py | 314 +++++++++++++++++++++++---------------- 1 file changed, 185 insertions(+), 129 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 3128a9874a..13ca4f534d 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -663,6 +663,107 @@ def close(self): os.unlink(self.socket_path) +class HandleExchanger(ABC): + """Abstract interface for exchanging CUDA shareable handles across ranks.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + self.comm = comm_backend + self.rank = group_rank + self.size = group_size + + @property + @abstractmethod + def handle_type(self) -> cuda.CUmemAllocationHandleType: + """The CUDA handle type this exchanger works with.""" + ... + + @abstractmethod + def allgather(self, local_handle) -> List: + """All-gather shareable handles from all ranks.""" + ... + + @abstractmethod + def broadcast(self, handle, root: int): + """Broadcast a handle from root to all ranks.""" + ... + + @abstractmethod + def cleanup(self, handle) -> None: ... + + @abstractmethod + def close(self) -> None: ... + + +class FabricHandleExchanger(HandleExchanger): + """Handle exchange using CUDA Fabric handles via MPI/collective backend.""" + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + + def allgather(self, local_handle) -> List: + return self.comm.allgather(local_handle.data) + + def broadcast(self, handle, root: int): + return self.comm.bcast(handle.data if handle else None, root=root) + + def cleanup(self, handle) -> None: + pass # No cleanup needed for Fabric handles. + + def close(self) -> None: + pass # No close needed for Fabric handles. + + +class PosixFDHandleExchanger(HandleExchanger): + """Handle exchange using POSIX file descriptors via IPC sockets.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + super().__init__(comm_backend, group_rank, group_size) + self._socket = self._init_ipc_socket() + + def _init_ipc_socket(self) -> IpcSocket: + if self.rank == 0: + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm.bcast(opId, root=0) + return IpcSocket(self.rank, opId) + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + + def allgather(self, local_handle) -> List: + result = [None] * self.size + for i in range(self.size): + self.comm.barrier() + self._socket.send_fd(local_handle, (self.rank + i) % self.size) + src = (self.rank + self.size - i) % self.size + result[src] = self._socket.recv_fd() + return result + + def broadcast(self, handle, root: int): + if self.rank == root: + for p in range(1, self.size): + self.comm.barrier() + self._socket.send_fd(handle, p) + return handle + else: + # Ordered receive to avoid race condition + for _ in range(self.rank): + self.comm.barrier() + result = self._socket.recv_fd() + for _ in range(self.size - self.rank - 1): + self.comm.barrier() + return result + + def cleanup(self, handle) -> None: + os.close(handle) + + def close(self) -> None: + self._socket.close() + + # TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -714,8 +815,6 @@ def __init__( int ] = [] # std::vector mUcHandles - self._shareable_handle_type = None - # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE @@ -741,6 +840,7 @@ def __init__( f"Signal pad offset: {self.signal_pad_offset}" ) + # Create handle exchanger based on multi-node mode if self.is_multi_node: # Check if fabric handle is supported fabric_handle_supported = checkCudaErrors( @@ -753,15 +853,12 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." ) - # Use fabric handle for multi-node NVLS - self._shareable_handle_type = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + self._exchanger: HandleExchanger = FabricHandleExchanger( + self.comm_backend, self.group_rank, self.group_size ) else: - self._init_ipc_socket() - # Use NVLink handle for single-node NVLS - self._shareable_handle_type = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._exchanger = PosixFDHandleExchanger( + self.comm_backend, self.group_rank, self.group_size ) self._alloc_mn_mcast_mem(buf_size) @@ -785,8 +882,8 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if hasattr(self, "_ipc_socket"): - self._ipc_socket.close() + if hasattr(self, "_exchanger"): + self._exchanger.close() # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. @@ -890,22 +987,23 @@ def get_usable_buffer_size(self) -> int: """Get the usable buffer size (excluding signal pad)""" return self.allocation_size - self.SIGNAL_PAD_SIZE - def _init_ipc_socket(self): - if self.group_rank == 0: - # Gnerate the opId - opId = random.randint(0, 2**64 - 1) - else: - opId = None - opId = self.comm_backend.bcast(opId, root=0) - self._ipc_socket = IpcSocket(self.group_rank, opId) - def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" + self._verify_cuda_context() + + # Compute allocation size and get allocation properties + allocation_prop, mc_prop = self._get_allocation_prop(buf_size) + + # Allocate, exchange, and map unicast buffers + self._allocate_unicast_buffers(allocation_prop) + + # Setup multicast object, exchange handles, map and bind memory + self._setup_multicast(mc_prop) - # Verify CUDA context + def _verify_cuda_context(self): + """Verify CUDA context is set to the correct device.""" try: current_device = checkCudaErrors(cuda.cuCtxGetDevice()) - if int(current_device) != self.device_idx: print( f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" @@ -913,16 +1011,16 @@ def _alloc_mn_mcast_mem(self, buf_size: int): except Exception as e: print(f"Error checking CUDA context: {e}") - # Set up allocation properties + def _get_allocation_prop(self, buf_size: int): + """Compute allocation size and return allocation/multicast properties.""" allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = self._shareable_handle_type + allocation_prop.requestedHandleTypes = self._exchanger.handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() allocation_prop.location.type = ( cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE ) allocation_prop.location.id = self.device_idx - allocation_prop.allocFlags.gpuDirectRDMACapable = 1 # Get allocation granularity @@ -933,7 +1031,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) - # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); self.allocation_size = round_up( buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity ) @@ -942,18 +1039,21 @@ def _alloc_mn_mcast_mem(self, buf_size: int): mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = self._shareable_handle_type + mc_prop.handleTypes = self._exchanger.handle_type - # Get multicast granularity - mc_granularity = checkCudaErrors( + # Get multicast granularity and adjust allocation size + self._mc_granularity = checkCudaErrors( cuda.cuMulticastGetGranularity( mc_prop, cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, ) ) + self.allocation_size = round_up(self.allocation_size, self._mc_granularity) - self.allocation_size = round_up(self.allocation_size, mc_granularity) + return allocation_prop, mc_prop + def _allocate_unicast_buffers(self, allocation_prop): + """Allocate local UC memory, exchange handles with peers, and map memory.""" # Initialize UC handles list self.uc_handles = [0] * self.group_size @@ -962,34 +1062,17 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) - # Export local handle to fabric handle or FD + # Export local handle to shareable handle local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - self._shareable_handle_type, + self._exchanger.handle_type, 0, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): - # All-gather fabric handles - all_shareable_uc_handles = self.comm_backend.allgather( - local_shareable_uc_handle.data - ) - else: - # Implement the allgather logic with ipc socket - all_shareable_uc_handles = [None] * self.group_size - for i in range(self.group_size): - self.comm_backend.barrier() - # Send to peer at offset i - dest_rank = (self.group_rank + i) % self.group_size - self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) - # Receive from peer at offset -i - src_rank = (self.group_rank + self.group_size - i) % self.group_size - all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() + # All-gather shareable handles + all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle) cuda.cuCtxSynchronize() # Import remote handles @@ -998,117 +1081,81 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( all_shareable_uc_handles[p], - self._shareable_handle_type, + self._exchanger.handle_type, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR - ): - # Close FD after import - os.close(all_shareable_uc_handles[p]) + self._exchanger.cleanup(all_shareable_uc_handles[p]) + + # Reserve address space for UC pointers + self.uc_ptrs = [0] * self.group_size + total_uc_size = self.allocation_size * self.group_size + self.total_uc_size = total_uc_size + uc_base_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(total_uc_size, self._mc_granularity, 0, 0) + ) + self.uc_base_ptr = uc_base_ptr + + # Map UC memory + for i in range(self.group_size): + offset = self.allocation_size * i + self.uc_ptrs[i] = int(uc_base_ptr) + offset + checkCudaErrors( + cuda.cuMemMap( + self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 + ) + ) - # Initialize multicasting + # Set memory access permissions for UC + access_desc = self._get_mem_access_desc() + checkCudaErrors( + cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) + ) + + def _setup_multicast(self, mc_prop): + """Create multicast object, exchange handle, map memory, and bind.""" + # Rank 0 creates the multicast object if self.group_rank == 0: - # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - - # Export multicast handle, there's only one handle for the entire group shareable_mc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, - self._shareable_handle_type, + self._exchanger.handle_type, 0, ) ) else: shareable_mc_handle = None - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): - # Broadcast multicast handle - shareable_mc_handle = self.comm_backend.bcast( - shareable_mc_handle.data if shareable_mc_handle else None, root=0 - ) - else: - # Implement bcast logic with ipc socket - if self.group_rank == 0: - for p in range(1, self.group_size): - self.comm_backend.barrier() - self._ipc_socket.send_fd(shareable_mc_handle, p) - else: - # Other ranks receive from rank 0 - # We need to order the receive to avoid a race condition bug we encountered. If driver fixed this issue, the additional barriers used for ordering can be removed. - for _ in range(self.group_rank): - self.comm_backend.barrier() - shareable_mc_handle = self._ipc_socket.recv_fd() - for _ in range(self.group_size - self.group_rank - 1): - self.comm_backend.barrier() - # Sync device to ensure broadcast is complete + + # Broadcast multicast handle from rank 0 + shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0) cuda.cuCtxSynchronize() + # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( shareable_mc_handle, - self._shareable_handle_type, + self._exchanger.handle_type, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR - ): - # Close FD after import - os.close(shareable_mc_handle) + self._exchanger.cleanup(shareable_mc_handle) + # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) - # Bind memory addresses - self.uc_ptrs = [0] * self.group_size - - # Reserve address space for UC pointers - total_uc_size = self.allocation_size * self.group_size - self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) - ) - self.uc_base_ptr = uc_base_ptr # Store for cleanup - - # Set up memory access descriptor - access_desc = cuda.CUmemAccessDesc() - access_desc.location = cuda.CUmemLocation() - access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - access_desc.location.id = self.device_idx - access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - - # Map UC memory - for i in range(self.group_size): - offset = self.allocation_size * i - self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors( - cuda.cuMemMap( - self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 - ) - ) - - # Set memory access permissions - checkCudaErrors( - cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) - ) - - # Bind MC pointer + # Reserve and map MC pointer self.mc_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + cuda.cuMemAddressReserve(self.allocation_size, self._mc_granularity, 0, 0) ) checkCudaErrors( cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) ) + access_desc = self._get_mem_access_desc() checkCudaErrors( cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) ) - # Bind memory to multicast + # Bind local memory to multicast checkCudaErrors( cuda.cuMulticastBindMem( self.mc_handle, @@ -1120,6 +1167,15 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) + def _get_mem_access_desc(self): + """Create memory access descriptor for this device.""" + access_desc = cuda.CUmemAccessDesc() + access_desc.location = cuda.CUmemLocation() + access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc.location.id = self.device_idx + access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + return access_desc + def lamport_initialize(self, rank: int, dtype: torch.dtype): if dtype == torch.bfloat16 or dtype == torch.float16: neg_zero = 0x8000