Skip to content

Commit b6736b5

Browse files
committed
One field for checkpoint, cleaner interface
1 parent 31a382d commit b6736b5

File tree

2 files changed

+64
-59
lines changed

2 files changed

+64
-59
lines changed

ot/lp/_network_simplex.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ def emd(
247247
log: dict, optional
248248
If input log is true, a dictionary containing the
249249
cost and dual variables and exit status. If warm_start=True,
250-
also contains internal solver state for resuming computation.
250+
also contains a "checkpoint" key with the internal solver state
251+
for resuming computation.
251252
252253
253254
Examples
@@ -268,6 +269,7 @@ def emd(
268269
269270
>>> # First call - save warm start data
270271
>>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True)
272+
>>> # log["checkpoint"] contains the solver state
271273
>>> # Resume from warm start
272274
>>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log)
273275
@@ -340,51 +342,47 @@ def emd(
340342

341343
if isinstance(warm_start, dict):
342344
# Resume from previous warm_start dict
345+
# Check if checkpoint is nested under "checkpoint" key or at top level
346+
if "checkpoint" in warm_start:
347+
chkpt = warm_start["checkpoint"]
348+
else:
349+
chkpt = warm_start
350+
343351
checkpoint_data = {
344-
"flow": nx.to_numpy(warm_start.get("_flow", warm_start.get("flow")))
345-
if ("_flow" in warm_start or "flow" in warm_start)
352+
"flow": nx.to_numpy(chkpt.get("flow", chkpt.get("_flow")))
353+
if ("flow" in chkpt or "_flow" in chkpt)
346354
else None,
347-
"pi": nx.to_numpy(warm_start.get("_pi", warm_start.get("pi")))
348-
if ("_pi" in warm_start or "pi" in warm_start)
355+
"pi": nx.to_numpy(chkpt.get("pi", chkpt.get("_pi")))
356+
if ("pi" in chkpt or "_pi" in chkpt)
349357
else None,
350-
"state": nx.to_numpy(warm_start.get("_state", warm_start.get("state")))
351-
if ("_state" in warm_start or "state" in warm_start)
358+
"state": nx.to_numpy(chkpt.get("state", chkpt.get("_state")))
359+
if ("state" in chkpt or "_state" in chkpt)
352360
else None,
353-
"parent": nx.to_numpy(warm_start.get("_parent", warm_start.get("parent")))
354-
if ("_parent" in warm_start or "parent" in warm_start)
361+
"parent": nx.to_numpy(chkpt.get("parent", chkpt.get("_parent")))
362+
if ("parent" in chkpt or "_parent" in chkpt)
355363
else None,
356-
"pred": nx.to_numpy(warm_start.get("_pred", warm_start.get("pred")))
357-
if ("_pred" in warm_start or "pred" in warm_start)
364+
"pred": nx.to_numpy(chkpt.get("pred", chkpt.get("_pred")))
365+
if ("pred" in chkpt or "_pred" in chkpt)
358366
else None,
359-
"thread": nx.to_numpy(warm_start.get("_thread", warm_start.get("thread")))
360-
if ("_thread" in warm_start or "thread" in warm_start)
367+
"thread": nx.to_numpy(chkpt.get("thread", chkpt.get("_thread")))
368+
if ("thread" in chkpt or "_thread" in chkpt)
361369
else None,
362-
"rev_thread": nx.to_numpy(
363-
warm_start.get("_rev_thread", warm_start.get("rev_thread"))
364-
)
365-
if ("_rev_thread" in warm_start or "rev_thread" in warm_start)
370+
"rev_thread": nx.to_numpy(chkpt.get("rev_thread", chkpt.get("_rev_thread")))
371+
if ("rev_thread" in chkpt or "_rev_thread" in chkpt)
366372
else None,
367-
"succ_num": nx.to_numpy(
368-
warm_start.get("_succ_num", warm_start.get("succ_num"))
369-
)
370-
if ("_succ_num" in warm_start or "succ_num" in warm_start)
373+
"succ_num": nx.to_numpy(chkpt.get("succ_num", chkpt.get("_succ_num")))
374+
if ("succ_num" in chkpt or "_succ_num" in chkpt)
371375
else None,
372-
"last_succ": nx.to_numpy(
373-
warm_start.get("_last_succ", warm_start.get("last_succ"))
374-
)
375-
if ("_last_succ" in warm_start or "last_succ" in warm_start)
376+
"last_succ": nx.to_numpy(chkpt.get("last_succ", chkpt.get("_last_succ")))
377+
if ("last_succ" in chkpt or "_last_succ" in chkpt)
376378
else None,
377-
"forward": nx.to_numpy(
378-
warm_start.get("_forward", warm_start.get("forward"))
379-
)
380-
if ("_forward" in warm_start or "forward" in warm_start)
379+
"forward": nx.to_numpy(chkpt.get("forward", chkpt.get("_forward")))
380+
if ("forward" in chkpt or "_forward" in chkpt)
381381
else None,
382382
"search_arc_num": int(
383-
warm_start.get("search_arc_num", warm_start.get("_search_arc_num", 0))
384-
),
385-
"all_arc_num": int(
386-
warm_start.get("all_arc_num", warm_start.get("_all_arc_num", 0))
383+
chkpt.get("search_arc_num", chkpt.get("_search_arc_num", 0))
387384
),
385+
"all_arc_num": int(chkpt.get("all_arc_num", chkpt.get("_all_arc_num", 0))),
388386
}
389387
# Filter out None values
390388
checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None}
@@ -425,18 +423,20 @@ def emd(
425423

426424
# Add checkpoint data if requested (preserve original dtypes, don't cast)
427425
if return_checkpoint and checkpoint_out is not None:
428-
log["_flow"] = checkpoint_out["flow"]
429-
log["_pi"] = checkpoint_out["pi"]
430-
log["_state"] = checkpoint_out["state"]
431-
log["_parent"] = checkpoint_out["parent"]
432-
log["_pred"] = checkpoint_out["pred"]
433-
log["_thread"] = checkpoint_out["thread"]
434-
log["_rev_thread"] = checkpoint_out["rev_thread"]
435-
log["_succ_num"] = checkpoint_out["succ_num"]
436-
log["_last_succ"] = checkpoint_out["last_succ"]
437-
log["_forward"] = checkpoint_out["forward"]
438-
log["search_arc_num"] = int(checkpoint_out["search_arc_num"])
439-
log["all_arc_num"] = int(checkpoint_out["all_arc_num"])
426+
log["checkpoint"] = {
427+
"flow": checkpoint_out["flow"],
428+
"pi": checkpoint_out["pi"],
429+
"state": checkpoint_out["state"],
430+
"parent": checkpoint_out["parent"],
431+
"pred": checkpoint_out["pred"],
432+
"thread": checkpoint_out["thread"],
433+
"rev_thread": checkpoint_out["rev_thread"],
434+
"succ_num": checkpoint_out["succ_num"],
435+
"last_succ": checkpoint_out["last_succ"],
436+
"forward": checkpoint_out["forward"],
437+
"search_arc_num": int(checkpoint_out["search_arc_num"]),
438+
"all_arc_num": int(checkpoint_out["all_arc_num"]),
439+
}
440440

441441
return nx.from_numpy(G, type_as=type_as), log
442442
return nx.from_numpy(G, type_as=type_as)

test/test_ot.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -971,8 +971,8 @@ def test_emd_checkpoint_multiple():
971971

972972
if log["result_code"] != 3: # converged
973973
break
974-
# Only use warm_start if checkpoint fields are present
975-
warm_start_data = log if "_flow" in log else None
974+
# Only use warm_start if checkpoint is present
975+
warm_start_data = log if "checkpoint" in log else None
976976

977977
# check cost decreases monotonically
978978
for i in range(len(costs) - 1):
@@ -991,23 +991,28 @@ def test_emd_checkpoint_structure():
991991

992992
G, log = ot.emd(a, b, M, numItermax=10, log=True, warm_start=True)
993993

994+
# Check that checkpoint key exists
995+
assert "checkpoint" in log, "Missing checkpoint key in log"
996+
997+
checkpoint = log["checkpoint"]
998+
994999
required_fields = [
995-
"_flow",
996-
"_pi",
997-
"_state",
998-
"_parent",
999-
"_pred",
1000-
"_thread",
1001-
"_rev_thread",
1002-
"_succ_num",
1003-
"_last_succ",
1004-
"_forward",
1005-
"search_arc_num", # scalars don't have underscore prefix
1000+
"flow",
1001+
"pi",
1002+
"state",
1003+
"parent",
1004+
"pred",
1005+
"thread",
1006+
"rev_thread",
1007+
"succ_num",
1008+
"last_succ",
1009+
"forward",
1010+
"search_arc_num",
10061011
"all_arc_num",
10071012
]
10081013

10091014
for field in required_fields:
1010-
assert field in log, f"Missing checkpoint field: {field}"
1015+
assert field in checkpoint, f"Missing checkpoint field: {field}"
10111016

10121017

10131018
def check_duality_gap(a, b, M, G, u, v, cost):

0 commit comments

Comments
 (0)