@@ -247,7 +247,8 @@ def emd(
247247 log: dict, optional
248248 If input log is true, a dictionary containing the
249249 cost and dual variables and exit status. If warm_start=True,
250- also contains internal solver state for resuming computation.
250+ also contains a "checkpoint" key with the internal solver state
251+ for resuming computation.
251252
252253
253254 Examples
@@ -268,6 +269,7 @@ def emd(
268269
269270 >>> # First call - save warm start data
270271 >>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True)
272+ >>> # log["checkpoint"] contains the solver state
271273 >>> # Resume from warm start
272274 >>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log)
273275
@@ -340,51 +342,47 @@ def emd(
340342
341343 if isinstance (warm_start , dict ):
342344 # Resume from previous warm_start dict
345+ # Check if checkpoint is nested under "checkpoint" key or at top level
346+ if "checkpoint" in warm_start :
347+ chkpt = warm_start ["checkpoint" ]
348+ else :
349+ chkpt = warm_start
350+
343351 checkpoint_data = {
344- "flow" : nx .to_numpy (warm_start .get ("_flow " , warm_start .get ("flow " )))
345- if ("_flow " in warm_start or "flow " in warm_start )
352+ "flow" : nx .to_numpy (chkpt .get ("flow " , chkpt .get ("_flow " )))
353+ if ("flow " in chkpt or "_flow " in chkpt )
346354 else None ,
347- "pi" : nx .to_numpy (warm_start .get ("_pi " , warm_start .get ("pi " )))
348- if ("_pi " in warm_start or "pi " in warm_start )
355+ "pi" : nx .to_numpy (chkpt .get ("pi " , chkpt .get ("_pi " )))
356+ if ("pi " in chkpt or "_pi " in chkpt )
349357 else None ,
350- "state" : nx .to_numpy (warm_start .get ("_state " , warm_start .get ("state " )))
351- if ("_state " in warm_start or "state " in warm_start )
358+ "state" : nx .to_numpy (chkpt .get ("state " , chkpt .get ("_state " )))
359+ if ("state " in chkpt or "_state " in chkpt )
352360 else None ,
353- "parent" : nx .to_numpy (warm_start .get ("_parent " , warm_start .get ("parent " )))
354- if ("_parent " in warm_start or "parent " in warm_start )
361+ "parent" : nx .to_numpy (chkpt .get ("parent " , chkpt .get ("_parent " )))
362+ if ("parent " in chkpt or "_parent " in chkpt )
355363 else None ,
356- "pred" : nx .to_numpy (warm_start .get ("_pred " , warm_start .get ("pred " )))
357- if ("_pred " in warm_start or "pred " in warm_start )
364+ "pred" : nx .to_numpy (chkpt .get ("pred " , chkpt .get ("_pred " )))
365+ if ("pred " in chkpt or "_pred " in chkpt )
358366 else None ,
359- "thread" : nx .to_numpy (warm_start .get ("_thread " , warm_start .get ("thread " )))
360- if ("_thread " in warm_start or "thread " in warm_start )
367+ "thread" : nx .to_numpy (chkpt .get ("thread " , chkpt .get ("_thread " )))
368+ if ("thread " in chkpt or "_thread " in chkpt )
361369 else None ,
362- "rev_thread" : nx .to_numpy (
363- warm_start .get ("_rev_thread" , warm_start .get ("rev_thread" ))
364- )
365- if ("_rev_thread" in warm_start or "rev_thread" in warm_start )
370+ "rev_thread" : nx .to_numpy (chkpt .get ("rev_thread" , chkpt .get ("_rev_thread" )))
371+ if ("rev_thread" in chkpt or "_rev_thread" in chkpt )
366372 else None ,
367- "succ_num" : nx .to_numpy (
368- warm_start .get ("_succ_num" , warm_start .get ("succ_num" ))
369- )
370- if ("_succ_num" in warm_start or "succ_num" in warm_start )
373+ "succ_num" : nx .to_numpy (chkpt .get ("succ_num" , chkpt .get ("_succ_num" )))
374+ if ("succ_num" in chkpt or "_succ_num" in chkpt )
371375 else None ,
372- "last_succ" : nx .to_numpy (
373- warm_start .get ("_last_succ" , warm_start .get ("last_succ" ))
374- )
375- if ("_last_succ" in warm_start or "last_succ" in warm_start )
376+ "last_succ" : nx .to_numpy (chkpt .get ("last_succ" , chkpt .get ("_last_succ" )))
377+ if ("last_succ" in chkpt or "_last_succ" in chkpt )
376378 else None ,
377- "forward" : nx .to_numpy (
378- warm_start .get ("_forward" , warm_start .get ("forward" ))
379- )
380- if ("_forward" in warm_start or "forward" in warm_start )
379+ "forward" : nx .to_numpy (chkpt .get ("forward" , chkpt .get ("_forward" )))
380+ if ("forward" in chkpt or "_forward" in chkpt )
381381 else None ,
382382 "search_arc_num" : int (
383- warm_start .get ("search_arc_num" , warm_start .get ("_search_arc_num" , 0 ))
384- ),
385- "all_arc_num" : int (
386- warm_start .get ("all_arc_num" , warm_start .get ("_all_arc_num" , 0 ))
383+ chkpt .get ("search_arc_num" , chkpt .get ("_search_arc_num" , 0 ))
387384 ),
385+ "all_arc_num" : int (chkpt .get ("all_arc_num" , chkpt .get ("_all_arc_num" , 0 ))),
388386 }
389387 # Filter out None values
390388 checkpoint_data = {k : v for k , v in checkpoint_data .items () if v is not None }
@@ -425,18 +423,20 @@ def emd(
425423
426424 # Add checkpoint data if requested (preserve original dtypes, don't cast)
427425 if return_checkpoint and checkpoint_out is not None :
428- log ["_flow" ] = checkpoint_out ["flow" ]
429- log ["_pi" ] = checkpoint_out ["pi" ]
430- log ["_state" ] = checkpoint_out ["state" ]
431- log ["_parent" ] = checkpoint_out ["parent" ]
432- log ["_pred" ] = checkpoint_out ["pred" ]
433- log ["_thread" ] = checkpoint_out ["thread" ]
434- log ["_rev_thread" ] = checkpoint_out ["rev_thread" ]
435- log ["_succ_num" ] = checkpoint_out ["succ_num" ]
436- log ["_last_succ" ] = checkpoint_out ["last_succ" ]
437- log ["_forward" ] = checkpoint_out ["forward" ]
438- log ["search_arc_num" ] = int (checkpoint_out ["search_arc_num" ])
439- log ["all_arc_num" ] = int (checkpoint_out ["all_arc_num" ])
426+ log ["checkpoint" ] = {
427+ "flow" : checkpoint_out ["flow" ],
428+ "pi" : checkpoint_out ["pi" ],
429+ "state" : checkpoint_out ["state" ],
430+ "parent" : checkpoint_out ["parent" ],
431+ "pred" : checkpoint_out ["pred" ],
432+ "thread" : checkpoint_out ["thread" ],
433+ "rev_thread" : checkpoint_out ["rev_thread" ],
434+ "succ_num" : checkpoint_out ["succ_num" ],
435+ "last_succ" : checkpoint_out ["last_succ" ],
436+ "forward" : checkpoint_out ["forward" ],
437+ "search_arc_num" : int (checkpoint_out ["search_arc_num" ]),
438+ "all_arc_num" : int (checkpoint_out ["all_arc_num" ]),
439+ }
440440
441441 return nx .from_numpy (G , type_as = type_as ), log
442442 return nx .from_numpy (G , type_as = type_as )
0 commit comments