Skip to content
125 changes: 69 additions & 56 deletions csrc/trtllm_mnnvl_allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,77 +26,90 @@ using tvm::ffi::Optional;
} \
}()

void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev,
int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks,
int64_t rank, bool wait_for_results, bool launch_with_pdl,
Optional<TensorView> out) {
cudaSetDevice(in.device().device_id);
auto stream = get_stream(in.device());
void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr,
int64_t buffer_ptrs_dev, int64_t buffer_ptr_local,
TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank,
bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot,
TensorView output, Optional<TensorView> residual_out,
Optional<TensorView> residual_in, Optional<TensorView> gamma,
Optional<double> epsilon) {
cudaSetDevice(input.device().device_id);
auto stream = get_stream(input.device());

DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] {
DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
// Extract parameters from tensors
int64_t num_tokens = in.size(0);
int64_t token_dim = in.size(1);
int64_t num_tokens = input.size(0);
int64_t token_dim = input.size(1);

// Validate input parameters
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type);
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type);
TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1))
<< "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << output.size(0) << ", " << output.size(1) << ")";
TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64)
<< "nranks must be between 2 and 64, got " << nranks;
TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
<< "rank must be between 0 and nranks-1, got " << rank;
TVM_FFI_ICHECK(out.has_value() || !wait_for_results)
<< "out tensor must be provided if wait_for_results is true";
TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
epsilon.has_value()) ||
!rmsnorm_fusion)
<< "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
"true";

if (rmsnorm_fusion) {
TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
residual_in.value().size(1) == token_dim)
<< "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
<< ")";
TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
residual_out.value().size(1) == token_dim)
<< "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
<< ")";
TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
<< "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
<< gamma.value().size(0) << ")";
}

// Create the parameters struct
AllReduceParams<c_type> params;
params.nranks = nranks;
params.rank = rank;
params.buffer_M = buffer_M;
params.num_tokens = num_tokens;
params.token_dim = token_dim;
params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
params.buffer_flags = buffer_flags_mnnvl.data_ptr();
params.wait_for_results = wait_for_results;
params.launch_with_pdl = launch_with_pdl;
params.input = in.data_ptr();
params.output = out.has_value() ? out.value().data_ptr() : nullptr;
params.stream = stream;
AllReduceFusionParams params;

auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "twoshot_allreduce_dispatch_world_size failed with error code "
<< cudaGetErrorString(status);
});
}
// Aux Information
params.nRanks = nranks;
params.rank = rank;
params.numTokens = num_tokens;
params.tokenDim = token_dim;
params.bufferPtrsDev = reinterpret_cast<void**>(buffer_ptrs_dev);
params.bufferPtrLocal = reinterpret_cast<void*>(buffer_ptr_local);
params.multicastPtr = reinterpret_cast<void*>(multicast_buffer_ptr);
params.bufferFlags = reinterpret_cast<uint32_t*>(buffer_flags_mnnvl.data_ptr());
params.rmsNormFusion = rmsnorm_fusion;
params.launchWithPdl = launch_with_pdl;

void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output,
TensorView normed_output, TensorView gamma, double epsilon,
TensorView residual, TensorView buffer_flags, bool launch_with_pdl) {
cudaSetDevice(prenorm_output.device().device_id);
auto stream = get_stream(prenorm_output.device());
// input data
params.input = const_cast<void const*>(input.data_ptr());
params.residualIn =
residual_in.has_value() ? const_cast<void const*>(residual_in.value().data_ptr()) : nullptr;
params.gamma = gamma.has_value() ? const_cast<void const*>(gamma.value().data_ptr()) : nullptr;
params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;

DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] {
// Create the parameters struct
RMSNormParams<c_type> params;
params.residual_output = prenorm_output.data_ptr();
params.output = normed_output.data_ptr();
params.input = reinterpret_cast<void const*>(multicast_buffer_ptr);
params.gamma = gamma.data_ptr();
params.epsilon = epsilon;
params.residual = residual.data_ptr();
params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
params.batch = normed_output.size(0);
params.hidden_dim = normed_output.size(1);
// output data
params.output = const_cast<void*>(output.data_ptr());
params.residualOut =
residual_out.has_value() ? const_cast<void*>(residual_out.value().data_ptr()) : nullptr;
params.stream = stream;
params.launch_with_pdl = launch_with_pdl;
auto status = twoshot_rmsnorm_dispatch_hidden_dim<c_type>(params);

cudaError_t status;
if (use_oneshot) {
status = oneshotAllreduceFusionDispatch<c_type>(params);
} else {
status = twoshotAllreduceFusionDispatch<c_type>(params);
}
TVM_FFI_ICHECK(status == cudaSuccess)
<< "twoshot_rmsnorm_dispatch_hidden_dim failed with error code "
<< cudaGetErrorString(status);
<< "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status);
});
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion);
Loading