@@ -360,6 +360,16 @@ def func(_x):
360360 if _type is None :
361361 [ret .append (func (x )) for x in X ]
362362
363+ elif _type is "starmap" :
364+
365+ if len (_params ) != 1 :
366+ raise Exception ("The starmap parallelization method must be accompanied by a starmapping callable" )
367+
368+ params = [[X [k ], calc_gradient , self ._evaluate , args , kwargs ] for k in range (len (X ))]
369+
370+ starmapper = _params [0 ]
371+ ret = list (starmapper (evaluate_in_parallel , params ))
372+
363373 elif _type == "threads" :
364374
365375 if len (_params ) == 0 :
@@ -368,11 +378,9 @@ def func(_x):
368378 n_threads = _params [0 ]
369379
370380 with ThreadPool (n_threads ) as pool :
371- params = []
372- for k in range (len (X )):
373- params .append ([X [k ], calc_gradient , self ._evaluate , args , kwargs ])
381+ params = [[X [k ], calc_gradient , self ._evaluate , args , kwargs ] for k in range (len (X ))]
374382
375- ret = np . array ( pool .starmap (evaluate_in_parallel , params ) )
383+ ret = pool .starmap (evaluate_in_parallel , params )
376384
377385 elif _type == "dask" :
378386
@@ -389,7 +397,7 @@ def func(_x):
389397 ret = [job .result () for job in jobs ]
390398
391399 else :
392- raise Exception ("Unknown parallelization method: %s (None, threads, dask)" % self . parallelization )
400+ raise Exception ("Unknown parallelization method: %s (should be one of: None, starmap, threads, dask)" % _type )
393401
394402 # stack all the single outputs together
395403 for key in ret [0 ].keys ():
@@ -449,6 +457,10 @@ def calc_constraint_violation(G):
449457 else :
450458 return np .sum (G * (G > 0 ).astype (np .float ), axis = 1 )[:, None ]
451459
460+ def __getstate__ (self ):
461+ state = self .__dict__ .copy ()
462+ state ["parallelization" ] = None
463+ return state
452464
453465# makes all the output at least 2-d dimensional
454466def at_least2d (d ):
0 commit comments