@@ -87,16 +87,18 @@ class Border : public torch::autograd::Function<Border> {
     int mpi_init = 0;
     MPI_Initialized(&mpi_init);
     int cuda_aware = 1;
-    int me;
+    int me = 0;
     MPI_Comm world;
     int world_size = 0;
-    unpack_communicator(communicator_tensor, world);
-    MPI_Comm_rank(world, &me);
-    MPI_Comm_size(world, &world_size);
+    if (mpi_init) {
+      unpack_communicator(communicator_tensor, world);
+      MPI_Comm_rank(world, &me);
+      MPI_Comm_size(world, &world_size);
+    }
     MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
     MPI_Request request;
 #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
-    if (world_size != 1) {
+    if (world_size >= 1) {
       int version, subversion;
       MPI_Get_version(&version, &subversion);
       if (version >= 4) {
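This hunk makes the forward pass safe to run without MPI: `me` and `world_size` now default to 0, and the communicator is unpacked only when `MPI_Initialized` reports true, so the later `world_size >= 1` test doubles as an "MPI is live" check. A minimal standalone sketch of the same guard pattern, using `MPI_COMM_WORLD` in place of the communicator unpacked by this file's `unpack_communicator` helper:

```cpp
#include <mpi.h>

// Sketch: query rank/size only when MPI is actually initialized;
// otherwise keep the rank-0 / size-0 defaults so a serial (non-MPI)
// run never touches communicator state.
int main() {
  int mpi_init = 0;
  MPI_Initialized(&mpi_init);  // one of the few calls legal before MPI_Init
  int me = 0;
  int world_size = 0;
  if (mpi_init) {
    MPI_Comm_rank(MPI_COMM_WORLD, &me);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
  }
  // world_size >= 1 now means "MPI is initialized"; 0 means serial run.
  return 0;
}
```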
@@ -120,11 +122,15 @@ class Border : public torch::autograd::Function<Border> {
     for (int iswap = 0; iswap < nswap; ++iswap) {
       int nrecv = recvnum[iswap];
       int nsend = sendnum[iswap];
-      torch::Tensor isendlist =
-          torch::from_blob(sendlist[iswap], {nsend}, int32_options)
-              .to(recv_g1_tensor.device());
-      torch::Tensor send_g1_tensor = recv_g1_tensor.index_select(0, isendlist);
-      FPTYPE* send_g1 = send_g1_tensor.data_ptr<FPTYPE>();
+      torch::Tensor isendlist;
+      torch::Tensor send_g1_tensor;
+      FPTYPE* send_g1;
+      if (nsend != 0) {
+        isendlist = torch::from_blob(sendlist[iswap], {nsend}, int32_options)
+                        .to(recv_g1_tensor.device());
+        send_g1_tensor = recv_g1_tensor.index_select(0, isendlist);
+        send_g1 = send_g1_tensor.data_ptr<FPTYPE>();
+      }
 #ifdef USE_MPI
       if (sendproc[iswap] != me) {
         if (nrecv) {
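The `nsend != 0` guard keeps `torch::from_blob` from wrapping a pointer that may be null or meaningless when a swap has nothing to send. A hedged sketch of the same pattern as a free function (the name `gather_send_rows` and the standalone signature are illustrative, not from this file):

```cpp
#include <torch/torch.h>

// Sketch: wrap a raw index buffer only when it is non-empty.
// `sendlist_swap` stands in for sendlist[iswap]; it may be null when
// this swap sends nothing, so from_blob must not see it in that case.
torch::Tensor gather_send_rows(const torch::Tensor& src,
                               int* sendlist_swap,
                               int nsend) {
  auto int32_options = torch::TensorOptions().dtype(torch::kInt32);
  if (nsend == 0) {
    // Nothing to send: return an empty tensor instead of calling
    // from_blob on a possibly null pointer.
    return torch::empty({0, src.size(1)}, src.options());
  }
  torch::Tensor idx =
      torch::from_blob(sendlist_swap, {nsend}, int32_options)
          .to(src.device());
  return src.index_select(0, idx);
}
```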
@@ -207,15 +213,17 @@ class Border : public torch::autograd::Function<Border> {
     MPI_Initialized(&mpi_init);
     int world_size = 0;
     int cuda_aware = 1;
+    int me = 0;
     MPI_Comm world;
-    unpack_communicator(communicator_tensor, world);
-    int me;
-    MPI_Comm_rank(world, &me);
-    MPI_Comm_size(world, &world_size);
+    if (mpi_init) {
+      unpack_communicator(communicator_tensor, world);
+      MPI_Comm_rank(world, &me);
+      MPI_Comm_size(world, &world_size);
+    }
     MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
     MPI_Request request;
 #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
-    if (world_size != 1) {
+    if (world_size >= 1) {
       int version, subversion;
       MPI_Get_version(&version, &subversion);
       if (version >= 4) {
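The backward pass gets the same treatment: `me` defaults to 0, communicator queries are gated on `mpi_init`, and `world_size >= 1` replaces `world_size != 1` so the CUDA-awareness probe only runs when MPI is up. A small sketch of the version probe itself (`MPI_Get_version`, like `MPI_Initialized`, is callable at any time, even before `MPI_Init`):

```cpp
#include <mpi.h>
#include <cstdio>

// Sketch: check the MPI standard version before relying on features
// (such as portable CUDA-awareness queries) introduced in MPI 4.0.
int main() {
  int version = 0, subversion = 0;
  MPI_Get_version(&version, &subversion);  // legal before MPI_Init
  if (version >= 4) {
    std::printf("MPI %d.%d detected\n", version, subversion);
  }
  return 0;
}
```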
@@ -248,17 +256,20 @@ class Border : public torch::autograd::Function<Border> {
     int nlocal = nlocal_tensor.item<int>();
     int nghost = nghost_tensor.item<int>();
     int ntotal = nlocal + nghost;
-
-    torch::Tensor send_g1_tensor = d_local_g1_tensor;
-
-    int max_recvnum = sendnum_tensor.max().item<int>();
-    auto options = torch::TensorOptions()
-                       .dtype(d_local_g1_tensor.dtype())
-                       .device(d_local_g1_tensor.device());
-    torch::Tensor recv_g1_tensor =
-        torch::empty({max_recvnum, tensor_size}, options);
-    FPTYPE* recv_g1 = recv_g1_tensor.data_ptr<FPTYPE>();
-    FPTYPE* send_g1 = send_g1_tensor.data_ptr<FPTYPE>() + ntotal * tensor_size;
+    torch::Tensor send_g1_tensor;
+    torch::Tensor recv_g1_tensor;
+    FPTYPE* recv_g1;
+    FPTYPE* send_g1;
+    if (nswap != 0) {
+      send_g1_tensor = d_local_g1_tensor;
+      int max_recvnum = sendnum_tensor.max().item<int>();
+      auto options = torch::TensorOptions()
+                         .dtype(d_local_g1_tensor.dtype())
+                         .device(d_local_g1_tensor.device());
+      recv_g1_tensor = torch::empty({max_recvnum, tensor_size}, options);
+      recv_g1 = recv_g1_tensor.data_ptr<FPTYPE>();
+      send_g1 = send_g1_tensor.data_ptr<FPTYPE>() + ntotal * tensor_size;
+    }
 
     int end = ntotal;
     auto int32_options = torch::TensorOptions().dtype(torch::kInt32);
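Here the scratch tensors and raw pointers are materialized only when there is at least one swap; with `nswap == 0` the tensors stay default-constructed and no `data_ptr` is taken, which matters because calling `data_ptr()` on an undefined tensor throws. A minimal sketch of this lazy-allocation guard (the free-function wrapper and its name are hypothetical; the tensor names mirror the code above):

```cpp
#include <torch/torch.h>

// Sketch: allocate the receive buffer only when there is work to do.
// Both the allocation and the pointer extraction sit behind the
// nswap guard, so nothing dereferences an unallocated tensor.
torch::Tensor make_recv_buffer(const torch::Tensor& d_local_g1,
                               const torch::Tensor& sendnum_tensor,
                               int nswap, int tensor_size) {
  torch::Tensor recv_g1_tensor;  // stays undefined when nswap == 0
  if (nswap != 0) {
    int max_recvnum = sendnum_tensor.max().item<int>();
    auto options = torch::TensorOptions()
                       .dtype(d_local_g1.dtype())
                       .device(d_local_g1.device());
    recv_g1_tensor = torch::empty({max_recvnum, tensor_size}, options);
    void* recv_g1 = recv_g1_tensor.data_ptr();  // safe only after allocation
    (void)recv_g1;
  }
  return recv_g1_tensor;
}
```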