Skip to content

Commit 79ac869

Browse files
authored
Merge pull request #46 from JoshKarpel/master
Add "starmap" parallel execution method
2 parents 0d03596 + f6d6931 commit 79ac869

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

pymoo/model/problem.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,16 @@ def func(_x):
360360
if _type is None:
361361
[ret.append(func(x)) for x in X]
362362

363+
elif _type is "starmap":
364+
365+
if len(_params) != 1:
366+
raise Exception("The starmap parallelization method must be accompanied by a starmapping callable")
367+
368+
params = [[X[k], calc_gradient, self._evaluate, args, kwargs] for k in range(len(X))]
369+
370+
starmapper = _params[0]
371+
ret = list(starmapper(evaluate_in_parallel, params))
372+
363373
elif _type == "threads":
364374

365375
if len(_params) == 0:
@@ -368,11 +378,9 @@ def func(_x):
368378
n_threads = _params[0]
369379

370380
with ThreadPool(n_threads) as pool:
371-
params = []
372-
for k in range(len(X)):
373-
params.append([X[k], calc_gradient, self._evaluate, args, kwargs])
381+
params = [[X[k], calc_gradient, self._evaluate, args, kwargs] for k in range(len(X))]
374382

375-
ret = np.array(pool.starmap(evaluate_in_parallel, params))
383+
ret = pool.starmap(evaluate_in_parallel, params)
376384

377385
elif _type == "dask":
378386

@@ -389,7 +397,7 @@ def func(_x):
389397
ret = [job.result() for job in jobs]
390398

391399
else:
392-
raise Exception("Unknown parallelization method: %s (None, threads, dask)" % self.parallelization)
400+
raise Exception("Unknown parallelization method: %s (should be one of: None, starmap, threads, dask)" % _type)
393401

394402
# stack all the single outputs together
395403
for key in ret[0].keys():
@@ -449,6 +457,10 @@ def calc_constraint_violation(G):
449457
else:
450458
return np.sum(G * (G > 0).astype(np.float), axis=1)[:, None]
451459

460+
def __getstate__(self):
461+
state = self.__dict__.copy()
462+
state["parallelization"] = None
463+
return state
452464

453465
# makes all the output at least 2-d dimensional
454466
def at_least2d(d):

tests/problems/test_parallel.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
22

3+
import multiprocessing
4+
35
import numpy as np
46

57
from pymoo.model.problem import Problem
@@ -23,6 +25,18 @@ def test_evaluation_in_threads_number(self):
2325
_F = MyProblemElementwise(parallelization=("threads", 2)).evaluate(X)
2426
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
2527

28+
def test_evaluation_with_multiprocessing_process_pool_starmap(self):
29+
X, F = self.get_data()
30+
with multiprocessing.Pool() as pool:
31+
_F = MyProblemElementwise(parallelization = ("starmap", pool.starmap)).evaluate(X)
32+
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
33+
34+
def test_evaluation_with_multiprocessing_thread_pool_starmap(self):
35+
X, F = self.get_data()
36+
with multiprocessing.pool.ThreadPool() as pool:
37+
_F = MyProblemElementwise(parallelization = ("starmap", pool.starmap)).evaluate(X)
38+
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
39+
2640

2741
class MyProblemElementwise(Problem):
2842

0 commit comments

Comments
 (0)