@@ -93,6 +93,12 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
9393
9494// TODO: include NCCL headers
9595#include < nccl.h>
// Compile-time switch for the NCCL user buffer registration code paths
// (ncclMemAlloc / ncclCommRegister / ncclCommDeregister / ncclMemFree).
//
// NCCL_UB_SUPPORT is non-zero when the included nccl.h reports a version code
// of at least NCCL_VERSION(2, 19, 1), and 0 otherwise. Headers that do not
// define the NCCL_VERSION(X, Y, Z) helper are treated as "not supported".
#ifdef NCCL_VERSION
#define NCCL_VERSION_UB NCCL_VERSION(2, 19, 1)
// Parenthesized so the comparison survives any surrounding operator, e.g.
// `#if !NCCL_UB_SUPPORT` or `#if NCCL_UB_SUPPORT && OTHER_FLAG`.
#define NCCL_UB_SUPPORT (NCCL_VERSION_CODE >= NCCL_VERSION_UB)
#else
#define NCCL_UB_SUPPORT 0
#endif
96102
97103#define NCCL_CALL (call ) \
98104 { \
@@ -168,7 +174,13 @@ int main(int argc, char* argv[]) {
168174 const int nx = get_argval<int >(argv, argv + argc, " -nx" , 16384 );
169175 const int ny = get_argval<int >(argv, argv + argc, " -ny" , 16384 );
170176 const bool csv = get_arg (argv, argv + argc, " -csv" );
171-
// Request flag for the NCCL user buffer registration path. It is read through
// get_arg like the -csv switch above (presumably true when the option is
// present on the command line; confirm against get_arg's definition).
// Non-const on purpose: it is reset below when the feature is unavailable.
177+ bool user_buffer_reg = get_arg (argv, argv + argc, " -user_buffer_reg" );
// Compiled only when NCCL_UB_SUPPORT evaluates to 0, i.e. the NCCL headers in
// use are older than 2.19.1 or do not define NCCL_VERSION.
178+ #if NCCL_UB_SUPPORT == 0
179+ if (user_buffer_reg) {
// Warn on stderr and continue with the regular cudaMalloc path instead of aborting.
180+ fprintf (stderr," WARNING: Ignoring -user_buffer_reg, required NCCL APIs are provided by NCCL 2.19.1 or later.\n " );
181+ user_buffer_reg = false ;
182+ }
183+ #endif // NCCL_UB_SUPPORT == 0
172184 int local_rank = -1 ;
173185 {
174186 MPI_Comm local_comm;
@@ -220,10 +232,27 @@ int main(int argc, char* argv[]) {
220232 chunk_size = chunk_size_high;
221233
222234 real* a;
223- CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
224235 real* a_new;
225- CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
226236
// Allocate the two device buffers a and a_new, each holding
// nx * (chunk_size + 2) elements of type real.
// With NCCL_UB_SUPPORT and user_buffer_reg set, the buffers come from
// ncclMemAlloc and are registered with nccl_comm; otherwise they come from
// cudaMalloc.
237+ #if NCCL_UB_SUPPORT
// Handles written by ncclCommRegister and later passed to ncclCommDeregister.
// They stay uninitialized when user_buffer_reg is false.
238+ void * a_reg_handle;
239+ void * a_new_reg_handle;
240+ if (user_buffer_reg) {
241+ // TODO: Allocate the memory with ncclMemAlloc and register it for the communicator
242+ NCCL_CALL (ncclMemAlloc ( (void **) &a , nx * (chunk_size + 2 ) * sizeof (real)));
243+ NCCL_CALL (ncclMemAlloc ( (void **) &a_new, nx * (chunk_size + 2 ) * sizeof (real)));
// Register the full extent of each buffer with the communicator.
244+ NCCL_CALL (ncclCommRegister (nccl_comm, a , nx * (chunk_size + 2 ) * sizeof (real), &a_reg_handle));
245+ NCCL_CALL (ncclCommRegister (nccl_comm, a_new, nx * (chunk_size + 2 ) * sizeof (real), &a_new_reg_handle));
// 22304 corresponds to the 2.23.4 release named in the warning. nccl_version is
// set outside this excerpt (presumably the runtime NCCL version; confirm).
// Only a warning is printed; the registered buffers are used regardless.
246+ if ( nccl_version < 22304 ) {
247+ fprintf (stderr," WARNING: -user_buffer_reg available, but Jacobi communication pattern needs NCCL 2.23.4 or later.\n " );
248+ }
249+ }
// The `else` lives inside the #if, so without NCCL_UB_SUPPORT the braces below
// form a plain block that always runs.
250+ else
251+ #endif // NCCL_UB_SUPPORT
252+ {
// Default path: plain CUDA device allocations of the same size.
253+ CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
254+ CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
255+ }
227256 CUDA_RT_CALL (cudaMemset (a, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
228257 CUDA_RT_CALL (cudaMemset (a_new, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
229258
@@ -403,10 +432,20 @@ int main(int argc, char* argv[]) {
403432
404433 CUDA_RT_CALL (cudaFreeHost (l2_norm_h));
405434 CUDA_RT_CALL (cudaFree (l2_norm_d));
406-
// Release a and a_new with the API family that allocated them:
// ncclCommDeregister + ncclMemFree on the user buffer registration path,
// cudaFree otherwise.
435+ #if NCCL_UB_SUPPORT
436+ if (user_buffer_reg) {
437+ // TODO: Deregister and Free the Buffer
// Deregister both handles from nccl_comm first, then free the memory.
438+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_new_reg_handle));
439+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_reg_handle));
440+ NCCL_CALL (ncclMemFree (a_new));
441+ NCCL_CALL (ncclMemFree (a));
442+ }
// As in the allocation code, the `else` is inside the #if, so without
// NCCL_UB_SUPPORT the block below always runs.
443+ else
444+ #endif // NCCL_UB_SUPPORT
445+ {
407446 CUDA_RT_CALL (cudaFree (a_new));
408447 CUDA_RT_CALL (cudaFree (a));
409-
448+ }
410449 CUDA_RT_CALL (cudaFreeHost (a_h));
411450 CUDA_RT_CALL (cudaFreeHost (a_ref_h));
412451
0 commit comments