@@ -172,8 +172,7 @@ def emd(
172172 center_dual = True ,
173173 numThreads = 1 ,
174174 check_marginals = True ,
175- checkpoint = None ,
176- return_checkpoint = False ,
175+ warm_start = False ,
177176):
178177 r"""Solves the Earth Movers distance problem and returns the OT matrix
179178
@@ -234,15 +233,10 @@ def emd(
234233 check_marginals: bool, optional (default=True)
235234 If True, checks that the marginals mass are equal. If False, skips the
236235 check.
237- checkpoint: dict, optional (default=None)
238- Checkpoint data from a previous emd() call to resume computation.
239- The checkpoint must contain internal solver state including flow,
240- potentials, and tree structure. Obtain by calling emd() with
241- return_checkpoint=True.
242- return_checkpoint: bool, optional (default=False)
243- If True and log=True, includes complete internal solver state in the
244- returned log dictionary for checkpointing. This enables pausing and
245- resuming the optimization.
236+ warm_start: bool or dict, optional (default=False)
237+ If True, returns warm start data in the log for resuming computation.
238+ If dict (from previous call with warm_start=True), resumes optimization
239+ from the provided state. Requires log=True when saving state.
246240
247241
248242 Returns
@@ -252,7 +246,7 @@ def emd(
252246 parameters
253247 log: dict, optional
254248 If input log is true, a dictionary containing the
255- cost and dual variables and exit status. If return_checkpoint =True,
249+ cost and dual variables and exit status. If warm_start =True,
256250 also contains internal solver state for resuming computation.
257251
258252
@@ -270,6 +264,13 @@ def emd(
270264 array([[0.5, 0. ],
271265 [0. , 0.5]])
272266
267+ Warm start example for resuming optimization:
268+
269+ >>> # First call - save warm start data
270+ >>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True)
271+ >>> # Resume from warm start
272+ >>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log)
273+
273274
274275 .. _references-emd:
275276 References
@@ -333,39 +334,67 @@ def emd(
333334
334335 numThreads = check_number_threads (numThreads )
335336
337+ # Handle warm_start parameter
336338 checkpoint_data = None
337- if checkpoint is not None :
338- # Extract checkpoint arrays and convert to numpy (strip leading underscore)
339+ return_checkpoint = False
340+
341+ if isinstance (warm_start , dict ):
342+ # Resume from previous warm_start dict
339343 checkpoint_data = {
340- "flow" : nx .to_numpy (checkpoint ["_flow" ]) if "_flow" in checkpoint else None ,
341- "pi" : nx .to_numpy (checkpoint ["_pi" ]) if "_pi" in checkpoint else None ,
342- "state" : nx .to_numpy (checkpoint ["_state" ])
343- if "_state" in checkpoint
344+ "flow" : nx .to_numpy (warm_start .get ("_flow" , warm_start .get ("flow" )))
345+ if ("_flow" in warm_start or "flow" in warm_start )
344346 else None ,
345- "parent " : nx .to_numpy (checkpoint [ "_parent" ] )
346- if "_parent " in checkpoint
347+ "pi " : nx .to_numpy (warm_start . get ( "_pi" , warm_start . get ( "pi" )) )
348+ if ( "_pi " in warm_start or "pi" in warm_start )
347349 else None ,
348- "pred" : nx .to_numpy (checkpoint ["_pred" ]) if "_pred" in checkpoint else None ,
349- "thread" : nx .to_numpy (checkpoint ["_thread" ])
350- if "_thread" in checkpoint
350+ "state" : nx .to_numpy (warm_start .get ("_state" , warm_start .get ("state" )))
351+ if ("_state" in warm_start or "state" in warm_start )
351352 else None ,
352- "rev_thread " : nx .to_numpy (checkpoint [ "_rev_thread" ] )
353- if "_rev_thread " in checkpoint
353+ "parent " : nx .to_numpy (warm_start . get ( "_parent" , warm_start . get ( "parent" )) )
354+ if ( "_parent " in warm_start or "parent" in warm_start )
354355 else None ,
355- "succ_num " : nx .to_numpy (checkpoint [ "_succ_num" ] )
356- if "_succ_num " in checkpoint
356+ "pred " : nx .to_numpy (warm_start . get ( "_pred" , warm_start . get ( "pred" )) )
357+ if ( "_pred " in warm_start or "pred" in warm_start )
357358 else None ,
358- "last_succ " : nx .to_numpy (checkpoint [ "_last_succ" ] )
359- if "_last_succ " in checkpoint
359+ "thread " : nx .to_numpy (warm_start . get ( "_thread" , warm_start . get ( "thread" )) )
360+ if ( "_thread " in warm_start or "thread" in warm_start )
360361 else None ,
361- "forward" : nx .to_numpy (checkpoint ["_forward" ])
362- if "_forward" in checkpoint
362+ "rev_thread" : nx .to_numpy (
363+ warm_start .get ("_rev_thread" , warm_start .get ("rev_thread" ))
364+ )
365+ if ("_rev_thread" in warm_start or "rev_thread" in warm_start )
366+ else None ,
367+ "succ_num" : nx .to_numpy (
368+ warm_start .get ("_succ_num" , warm_start .get ("succ_num" ))
369+ )
370+ if ("_succ_num" in warm_start or "succ_num" in warm_start )
363371 else None ,
364- "search_arc_num" : int (checkpoint .get ("search_arc_num" , 0 )),
365- "all_arc_num" : int (checkpoint .get ("all_arc_num" , 0 )),
372+ "last_succ" : nx .to_numpy (
373+ warm_start .get ("_last_succ" , warm_start .get ("last_succ" ))
374+ )
375+ if ("_last_succ" in warm_start or "last_succ" in warm_start )
376+ else None ,
377+ "forward" : nx .to_numpy (
378+ warm_start .get ("_forward" , warm_start .get ("forward" ))
379+ )
380+ if ("_forward" in warm_start or "forward" in warm_start )
381+ else None ,
382+ "search_arc_num" : int (
383+ warm_start .get ("search_arc_num" , warm_start .get ("_search_arc_num" , 0 ))
384+ ),
385+ "all_arc_num" : int (
386+ warm_start .get ("all_arc_num" , warm_start .get ("_all_arc_num" , 0 ))
387+ ),
366388 }
367389 # Filter out None values
368390 checkpoint_data = {k : v for k , v in checkpoint_data .items () if v is not None }
391+ elif warm_start is True :
392+ # Save warm_start data - requires log=True
393+ if not log :
394+ raise ValueError (
395+ "warm_start=True requires log=True to return the warm start data"
396+ )
397+ return_checkpoint = True
369398
370399 G , cost , u , v , result_code , checkpoint_out = emd_c (
371400 a , b , M , numItermax , numThreads , checkpoint_data , int (return_checkpoint )
0 commit comments