Skip to content

Commit 31a382d

Browse files
committed
Add checkpoint/resume functionality to EMD solver : added parameter declaration to only have one now
1 parent 76d6283 commit 31a382d

File tree

2 files changed

+88
-51
lines changed

2 files changed

+88
-51
lines changed

ot/lp/_network_simplex.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ def emd(
172172
center_dual=True,
173173
numThreads=1,
174174
check_marginals=True,
175-
checkpoint=None,
176-
return_checkpoint=False,
175+
warm_start=False,
177176
):
178177
r"""Solves the Earth Movers distance problem and returns the OT matrix
179178
@@ -234,15 +233,10 @@ def emd(
234233
check_marginals: bool, optional (default=True)
235234
If True, checks that the marginals mass are equal. If False, skips the
236235
check.
237-
checkpoint: dict, optional (default=None)
238-
Checkpoint data from a previous emd() call to resume computation.
239-
The checkpoint must contain internal solver state including flow,
240-
potentials, and tree structure. Obtain by calling emd() with
241-
return_checkpoint=True.
242-
return_checkpoint: bool, optional (default=False)
243-
If True and log=True, includes complete internal solver state in the
244-
returned log dictionary for checkpointing. This enables pausing and
245-
resuming the optimization.
236+
warm_start: bool or dict, optional (default=False)
237+
If True, returns warm start data in the log for resuming computation.
238+
If dict (from previous call with warm_start=True), resumes optimization
239+
from the provided state. Requires log=True when saving state.
246240
247241
248242
Returns
@@ -252,7 +246,7 @@ def emd(
252246
parameters
253247
log: dict, optional
254248
If input log is true, a dictionary containing the
255-
cost and dual variables and exit status. If return_checkpoint=True,
249+
cost and dual variables and exit status. If warm_start=True,
256250
also contains internal solver state for resuming computation.
257251
258252
@@ -270,6 +264,13 @@ def emd(
270264
array([[0.5, 0. ],
271265
[0. , 0.5]])
272266
267+
Warm start example for resuming optimization:
268+
269+
>>> # First call - save warm start data
270+
>>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True)
271+
>>> # Resume from warm start
272+
>>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log)
273+
273274
274275
.. _references-emd:
275276
References
@@ -333,39 +334,67 @@ def emd(
333334

334335
numThreads = check_number_threads(numThreads)
335336

337+
# Handle warm_start parameter
336338
checkpoint_data = None
337-
if checkpoint is not None:
338-
# Extract checkpoint arrays and convert to numpy (strip leading underscore)
339+
return_checkpoint = False
340+
341+
if isinstance(warm_start, dict):
342+
# Resume from previous warm_start dict
339343
checkpoint_data = {
340-
"flow": nx.to_numpy(checkpoint["_flow"]) if "_flow" in checkpoint else None,
341-
"pi": nx.to_numpy(checkpoint["_pi"]) if "_pi" in checkpoint else None,
342-
"state": nx.to_numpy(checkpoint["_state"])
343-
if "_state" in checkpoint
344+
"flow": nx.to_numpy(warm_start.get("_flow", warm_start.get("flow")))
345+
if ("_flow" in warm_start or "flow" in warm_start)
344346
else None,
345-
"parent": nx.to_numpy(checkpoint["_parent"])
346-
if "_parent" in checkpoint
347+
"pi": nx.to_numpy(warm_start.get("_pi", warm_start.get("pi")))
348+
if ("_pi" in warm_start or "pi" in warm_start)
347349
else None,
348-
"pred": nx.to_numpy(checkpoint["_pred"]) if "_pred" in checkpoint else None,
349-
"thread": nx.to_numpy(checkpoint["_thread"])
350-
if "_thread" in checkpoint
350+
"state": nx.to_numpy(warm_start.get("_state", warm_start.get("state")))
351+
if ("_state" in warm_start or "state" in warm_start)
351352
else None,
352-
"rev_thread": nx.to_numpy(checkpoint["_rev_thread"])
353-
if "_rev_thread" in checkpoint
353+
"parent": nx.to_numpy(warm_start.get("_parent", warm_start.get("parent")))
354+
if ("_parent" in warm_start or "parent" in warm_start)
354355
else None,
355-
"succ_num": nx.to_numpy(checkpoint["_succ_num"])
356-
if "_succ_num" in checkpoint
356+
"pred": nx.to_numpy(warm_start.get("_pred", warm_start.get("pred")))
357+
if ("_pred" in warm_start or "pred" in warm_start)
357358
else None,
358-
"last_succ": nx.to_numpy(checkpoint["_last_succ"])
359-
if "_last_succ" in checkpoint
359+
"thread": nx.to_numpy(warm_start.get("_thread", warm_start.get("thread")))
360+
if ("_thread" in warm_start or "thread" in warm_start)
360361
else None,
361-
"forward": nx.to_numpy(checkpoint["_forward"])
362-
if "_forward" in checkpoint
362+
"rev_thread": nx.to_numpy(
363+
warm_start.get("_rev_thread", warm_start.get("rev_thread"))
364+
)
365+
if ("_rev_thread" in warm_start or "rev_thread" in warm_start)
366+
else None,
367+
"succ_num": nx.to_numpy(
368+
warm_start.get("_succ_num", warm_start.get("succ_num"))
369+
)
370+
if ("_succ_num" in warm_start or "succ_num" in warm_start)
363371
else None,
364-
"search_arc_num": int(checkpoint.get("search_arc_num", 0)),
365-
"all_arc_num": int(checkpoint.get("all_arc_num", 0)),
372+
"last_succ": nx.to_numpy(
373+
warm_start.get("_last_succ", warm_start.get("last_succ"))
374+
)
375+
if ("_last_succ" in warm_start or "last_succ" in warm_start)
376+
else None,
377+
"forward": nx.to_numpy(
378+
warm_start.get("_forward", warm_start.get("forward"))
379+
)
380+
if ("_forward" in warm_start or "forward" in warm_start)
381+
else None,
382+
"search_arc_num": int(
383+
warm_start.get("search_arc_num", warm_start.get("_search_arc_num", 0))
384+
),
385+
"all_arc_num": int(
386+
warm_start.get("all_arc_num", warm_start.get("_all_arc_num", 0))
387+
),
366388
}
367389
# Filter out None values
368390
checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None}
391+
elif warm_start is True:
392+
# Save warm_start data - requires log=True
393+
if not log:
394+
raise ValueError(
395+
"warm_start=True requires log=True to return the warm start data"
396+
)
397+
return_checkpoint = True
369398

370399
G, cost, u, v, result_code, checkpoint_out = emd_c(
371400
a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint)

test/test_ot.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -924,12 +924,10 @@ def test_emd_checkpoint():
924924

925925
G_ref, log_ref = ot.emd(a, b, M, numItermax=10000, log=True)
926926

927-
G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, return_checkpoint=True)
927+
G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, warm_start=True)
928928

929929
if log1["result_code"] == 3: # MAX_ITER_REACHED ?
930-
G2, log2 = ot.emd(
931-
a, b, M, numItermax=10000, log=True, checkpoint=log1, return_checkpoint=True
932-
)
930+
G2, log2 = ot.emd(a, b, M, numItermax=10000, log=True, warm_start=log1)
933931

934932
np.testing.assert_allclose(log2["cost"], log_ref["cost"], rtol=1e-6)
935933
np.testing.assert_allclose(G2, G_ref, rtol=1e-6)
@@ -947,24 +945,34 @@ def test_emd_checkpoint_multiple():
947945

948946
# multiple checkpoint phases with increasing iteration budgets
949947
max_iters = [100, 300, 600, 1000]
950-
checkpoint = None
948+
warm_start_data = None
951949
costs = []
952950

953951
for max_iter in max_iters:
954-
G, log = ot.emd(
955-
a,
956-
b,
957-
M,
958-
numItermax=max_iter,
959-
log=True,
960-
checkpoint=checkpoint,
961-
return_checkpoint=True,
962-
)
952+
if warm_start_data is None:
953+
G, log = ot.emd(
954+
a,
955+
b,
956+
M,
957+
numItermax=max_iter,
958+
log=True,
959+
warm_start=True,
960+
)
961+
else:
962+
G, log = ot.emd(
963+
a,
964+
b,
965+
M,
966+
numItermax=max_iter,
967+
log=True,
968+
warm_start=warm_start_data,
969+
)
963970
costs.append(log["cost"])
964971

965972
if log["result_code"] != 3: # converged
966973
break
967-
checkpoint = log
974+
# Only use warm_start if checkpoint fields are present
975+
warm_start_data = log if "_flow" in log else None
968976

969977
# check cost decreases monotonically
970978
for i in range(len(costs) - 1):
@@ -981,7 +989,7 @@ def test_emd_checkpoint_structure():
981989
b = ot.utils.unif(n)
982990
M = np.random.rand(n, n)
983991

984-
G, log = ot.emd(a, b, M, numItermax=10, log=True, return_checkpoint=True)
992+
G, log = ot.emd(a, b, M, numItermax=10, log=True, warm_start=True)
985993

986994
required_fields = [
987995
"_flow",
@@ -994,8 +1002,8 @@ def test_emd_checkpoint_structure():
9941002
"_succ_num",
9951003
"_last_succ",
9961004
"_forward",
997-
"_search_arc_num",
998-
"_all_arc_num",
1005+
"search_arc_num", # scalars don't have underscore prefix
1006+
"all_arc_num",
9991007
]
10001008

10011009
for field in required_fields:

0 commit comments

Comments
 (0)