Skip to content

Commit 467d2f2

Browse files
authored
Merge pull request #95 from pyiron/extract_variables
extract_vairable()
2 parents 170b183 + 6a339c1 commit 467d2f2

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

pylammpsmpi/mpi/lmpmpi.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,10 @@ def convert_data(val, type, length, width):
7777
val = job.extract_compute(*filtered_args)
7878
return convert_data(val=val, type=type, length=length, width=width)
7979
elif style == 1: # per atom property
80-
val = job.numpy.extract_compute(*filtered_args)
81-
val_gather = MPI.COMM_WORLD.gather(val, root=0)
80+
val = _gather_data_from_all_processors(
81+
data=job.numpy.extract_compute(*filtered_args)
82+
)
8283
if MPI.COMM_WORLD.rank == 0:
83-
# val_gather.shape [number of cores, atoms on specific core]
84-
# the number of atoms on specific cores can vary
85-
val = []
86-
for vl in val_gather:
87-
for v in vl:
88-
val.append(v)
8984
length = job.get_natoms()
9085
return convert_data(val=val, type=type, length=length, width=width)
9186
else: # Todo
@@ -165,15 +160,20 @@ def extract_fix(funct_args):
165160
def extract_variable(funct_args):
166161
# in the args - if the third one,
167162
# which is the type is 1 - a lammps array is returned
168-
if MPI.COMM_WORLD.rank == 0:
169-
# if type is 1 - reformat file
170-
try:
171-
data = job.extract_variable(*funct_args)
172-
except ValueError:
173-
return []
174-
if funct_args[2] == 1:
175-
data = np.array(data)
176-
return data
163+
if funct_args[2] == 1:
164+
data = _gather_data_from_all_processors(
165+
data=job.numpy.extract_variable(*funct_args)
166+
)
167+
if MPI.COMM_WORLD.rank == 0:
168+
return np.array(data)
169+
else:
170+
if MPI.COMM_WORLD.rank == 0:
171+
# if type is 1 - reformat file
172+
try:
173+
data = job.extract_variable(*funct_args)
174+
except ValueError:
175+
return []
176+
return data
177177

178178

179179
def get_natoms(funct_args):
@@ -472,6 +472,16 @@ def select_cmd(argument):
472472
return switcher.get(argument)
473473

474474

475+
def _gather_data_from_all_processors(data):
476+
data_gather = MPI.COMM_WORLD.gather(data, root=0)
477+
if MPI.COMM_WORLD.rank == 0:
478+
data = []
479+
for vl in data_gather:
480+
for v in vl:
481+
data.append(v)
482+
return data
483+
484+
475485
if __name__ == "__main__":
476486
while True:
477487
if MPI.COMM_WORLD.rank == 0:

tests/test_pylammpsmpi_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_extract_variable(self):
5959
self.assertEqual(np.round(x, 2), 1.13)
6060

6161
x = self.lmp.extract_variable("fx", "all", 1)
62-
self.assertEqual(len(x), 128)
62+
self.assertEqual(len(x), 256)
6363
self.assertEqual(np.round(x[0], 2), -0.26)
6464

6565
def test_scatter_atoms(self):

tests/test_pylammpsmpi_local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_extract_variable(self):
5555
x = self.lmp.extract_variable("tt", "all", 0)
5656
self.assertEqual(np.round(x, 2), 1.13)
5757
x = self.lmp.extract_variable("fx", "all", 1)
58-
self.assertEqual(len(x), 128)
58+
self.assertEqual(len(x), 256)
5959
self.assertEqual(np.round(x[0], 2), -0.26)
6060

6161
def test_scatter_atoms(self):

0 commit comments

Comments
 (0)