Skip to content

Commit 87a7c16

Browse files
committed
add "starmap" parallel execution method, with tests
1 parent 0d03596 commit 87a7c16

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

pymoo/model/problem.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,16 @@ def func(_x):
360360
if _type is None:
361361
[ret.append(func(x)) for x in X]
362362

363+
elif _type is "starmap":
364+
365+
if len(_params) != 1:
366+
raise Exception("The starmap parallelization method must be accompanied by a starmapping callable")
367+
368+
params = [[X[k], calc_gradient, self._evaluate, args, kwargs] for k in range(len(X))]
369+
370+
starmapper = _params[0]
371+
ret = starmapper(evaluate_in_parallel, params)
372+
363373
elif _type == "threads":
364374

365375
if len(_params) == 0:
@@ -368,11 +378,9 @@ def func(_x):
368378
n_threads = _params[0]
369379

370380
with ThreadPool(n_threads) as pool:
371-
params = []
372-
for k in range(len(X)):
373-
params.append([X[k], calc_gradient, self._evaluate, args, kwargs])
381+
params = [[X[k], calc_gradient, self._evaluate, args, kwargs] for k in range(len(X))]
374382

375-
ret = np.array(pool.starmap(evaluate_in_parallel, params))
383+
ret = pool.starmap(evaluate_in_parallel, params)
376384

377385
elif _type == "dask":
378386

@@ -389,7 +397,7 @@ def func(_x):
389397
ret = [job.result() for job in jobs]
390398

391399
else:
392-
raise Exception("Unknown parallelization method: %s (None, threads, dask)" % self.parallelization)
400+
raise Exception("Unknown parallelization method: %s (should be one of: None, starmap, threads, dask)" % _type)
393401

394402
# stack all the single outputs together
395403
for key in ret[0].keys():
@@ -449,6 +457,10 @@ def calc_constraint_violation(G):
449457
else:
450458
return np.sum(G * (G > 0).astype(np.float), axis=1)[:, None]
451459

460+
def __getstate__(self):
461+
state = self.__dict__.copy()
462+
state["parallelization"] = None
463+
return state
452464

453465
# makes all the output at least 2-d dimensional
454466
def at_least2d(d):

test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import autograd.numpy as anp
2+
import numpy as np
3+
from pymoo.algorithms.nsga2 import NSGA2
4+
from pymoo.factory import get_sampling, get_crossover, get_mutation, get_termination
5+
from pymoo.optimize import minimize
6+
from pymoo.util.misc import stack
7+
from pymoo.model.problem import Problem
8+
9+
import multiprocessing
10+
11+
class MyProblem(Problem):
12+
def __init__(self, **kwargs):
13+
super().__init__(n_var = 10, n_obj = 1, n_constr = 0, xl = -5, xu = 5,
14+
**kwargs)
15+
16+
def _evaluate(self, x, out, *args, **kwargs):
17+
out["F"] = (x ** 2).sum(axis = -1)
18+
19+
with multiprocessing.Pool() as pool:
20+
problem = MyProblem(elementwise_evaluation = True, parallelization = ('starmap', pool.starmap))
21+
r = np.random.RandomState(seed = 1)
22+
X = r.random((5, problem.n_var))
23+
result = problem.evaluate(X).squeeze()
24+
expected = (X ** 2).sum(axis = -1)
25+
26+
print(result)
27+
print(expected)
28+
29+
assert np.all(result == expected)
30+
print('yay')

tests/problems/test_parallel.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import unittest
22

3+
import multiprocessing
4+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
5+
36
import numpy as np
47

58
from pymoo.model.problem import Problem
@@ -23,6 +26,18 @@ def test_evaluation_in_threads_number(self):
2326
_F = MyProblemElementwise(parallelization=("threads", 2)).evaluate(X)
2427
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
2528

29+
def test_evaluation_with_multiprocessing_process_pool_starmap(self):
30+
X, F = self.get_data()
31+
with multiprocessing.Pool() as pool:
32+
_F = MyProblemElementwise(parallelization = ("starmap", pool.starmap)).evaluate(X)
33+
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
34+
35+
def test_evaluation_with_multiprocessing_thread_pool_starmap(self):
36+
X, F = self.get_data()
37+
with multiprocessing.pool.ThreadPool() as pool:
38+
_F = MyProblemElementwise(parallelization = ("starmap", pool.starmap)).evaluate(X)
39+
self.assertTrue(np.all(np.abs(_F - F) < 0.00001))
40+
2641

2742
class MyProblemElementwise(Problem):
2843

0 commit comments

Comments
 (0)