Skip to content

Commit 76d6283

Browse files
committed
Add checkpoint/resume functionality to EMD solver
Implements pause and resume capabilities for the EMD (Earth Mover's Distance) solver, allowing long-running optimizations to be interrupted and continued from their exact state. Changes: - Modified EMD_wrap() signature to accept checkpoint parameters for saving and restoring complete internal solver state (flow, potentials, tree structure) - Added saveCheckpoint(), restoreCheckpoint(), and runFromCheckpoint() methods to NetworkSimplexSimple class in network_simplex_simple.h - Extended emd_c() Cython wrapper to handle checkpoint dictionaries with 12 fields (10 arrays + 2 scalar arc counts) - Added 'checkpoint' and 'return_checkpoint' parameters to emd() Python API - Includes search_arc_num and all_arc_num scalars in checkpoint to preserve initialization state required by start() method Tests: - Added test_emd_checkpoint() for basic save/resume functionality - Added test_emd_checkpoint_multiple() for multiple pause/resume cycles - Added test_emd_checkpoint_structure() to verify checkpoint field integrity
1 parent d3867c6 commit 76d6283

File tree

6 files changed

+458
-34
lines changed

6 files changed

+458
-34
lines changed

ot/lp/EMD.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,16 @@ enum ProblemType {
2929
MAX_ITER_REACHED
3030
};
3131

32-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
32+
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
33+
double* alpha, double* beta, double *cost, uint64_t maxIter,
34+
int resume_mode=0, int return_checkpoint=0,
35+
double* flow_state=nullptr, double* pi_state=nullptr,
36+
signed char* state_state=nullptr, int* parent_state=nullptr,
37+
int64_t* pred_state=nullptr, int* thread_state=nullptr,
38+
int* rev_thread_state=nullptr, int* succ_num_state=nullptr,
39+
int* last_succ_state=nullptr, signed char* forward_state=nullptr,
40+
int64_t* search_arc_num_out=nullptr, int64_t* all_arc_num_out=nullptr);
41+
3342
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
3443

3544

ot/lp/EMD_wrapper.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@
2020

2121

2222
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
23-
double* alpha, double* beta, double *cost, uint64_t maxIter) {
23+
double* alpha, double* beta, double *cost, uint64_t maxIter,
24+
int resume_mode, int return_checkpoint,
25+
double* flow_state, double* pi_state, signed char* state_state,
26+
int* parent_state, int64_t* pred_state,
27+
int* thread_state, int* rev_thread_state,
28+
int* succ_num_state, int* last_succ_state,
29+
signed char* forward_state,
30+
int64_t* search_arc_num_out, int64_t* all_arc_num_out) {
2431
// beware M and C are stored in row major C style!!!
2532

2633
using namespace lemon;
@@ -93,8 +100,29 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
93100

94101

95102
// Solve the problem with the network simplex algorithm
96-
97-
int ret=net.run();
103+
// If resume_mode=1 and checkpoint data provided, resume from checkpoint
104+
// Otherwise do normal run
105+
106+
int64_t search_arc_num_in = 0, all_arc_num_in = 0;
107+
if (resume_mode == 1 && search_arc_num_out != nullptr && all_arc_num_out != nullptr) {
108+
search_arc_num_in = *search_arc_num_out;
109+
all_arc_num_in = *all_arc_num_out;
110+
}
111+
112+
int ret;
113+
if (resume_mode == 1 && flow_state != nullptr) {
114+
// Resume from checkpoint
115+
ret = net.runFromCheckpoint(
116+
flow_state, pi_state, state_state,
117+
parent_state, pred_state,
118+
thread_state, rev_thread_state,
119+
succ_num_state, last_succ_state, forward_state,
120+
search_arc_num_in, all_arc_num_in);
121+
} else {
122+
// Normal run
123+
ret = net.run();
124+
}
125+
98126
uint64_t i, j;
99127
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
100128
*cost = 0;
@@ -111,16 +139,22 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
111139

112140
}
113141

142+
// Save checkpoint if requested and arrays provided
143+
if (return_checkpoint == 1 && flow_state != nullptr) {
144+
net.saveCheckpoint(
145+
flow_state, pi_state, state_state,
146+
parent_state, pred_state,
147+
thread_state, rev_thread_state,
148+
succ_num_state, last_succ_state, forward_state,
149+
search_arc_num_out, all_arc_num_out);
150+
}
114151

115152
return ret;
116153
}
117154

118155

119156

120157

121-
122-
123-
124158
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
125159
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
126160
// beware M and C are stored in row major C style!!!

ot/lp/_network_simplex.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def emd(
172172
center_dual=True,
173173
numThreads=1,
174174
check_marginals=True,
175+
checkpoint=None,
176+
return_checkpoint=False,
175177
):
176178
r"""Solves the Earth Movers distance problem and returns the OT matrix
177179
@@ -232,6 +234,15 @@ def emd(
232234
check_marginals: bool, optional (default=True)
233235
If True, checks that the marginals mass are equal. If False, skips the
234236
check.
237+
checkpoint: dict, optional (default=None)
238+
Checkpoint data from a previous emd() call to resume computation.
239+
The checkpoint must contain internal solver state including flow,
240+
potentials, and tree structure. Obtain by calling emd() with
241+
return_checkpoint=True.
242+
return_checkpoint: bool, optional (default=False)
243+
If True and log=True, includes complete internal solver state in the
244+
returned log dictionary for checkpointing. This enables pausing and
245+
resuming the optimization.
235246
236247
237248
Returns
@@ -241,7 +252,8 @@ def emd(
241252
parameters
242253
log: dict, optional
243254
If input log is true, a dictionary containing the
244-
cost and dual variables and exit status
255+
cost and dual variables and exit status. If return_checkpoint=True,
256+
also contains internal solver state for resuming computation.
245257
246258
247259
Examples
@@ -321,7 +333,43 @@ def emd(
321333

322334
numThreads = check_number_threads(numThreads)
323335

324-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
336+
checkpoint_data = None
337+
if checkpoint is not None:
338+
# Extract checkpoint arrays and convert to numpy (strip leading underscore)
339+
checkpoint_data = {
340+
"flow": nx.to_numpy(checkpoint["_flow"]) if "_flow" in checkpoint else None,
341+
"pi": nx.to_numpy(checkpoint["_pi"]) if "_pi" in checkpoint else None,
342+
"state": nx.to_numpy(checkpoint["_state"])
343+
if "_state" in checkpoint
344+
else None,
345+
"parent": nx.to_numpy(checkpoint["_parent"])
346+
if "_parent" in checkpoint
347+
else None,
348+
"pred": nx.to_numpy(checkpoint["_pred"]) if "_pred" in checkpoint else None,
349+
"thread": nx.to_numpy(checkpoint["_thread"])
350+
if "_thread" in checkpoint
351+
else None,
352+
"rev_thread": nx.to_numpy(checkpoint["_rev_thread"])
353+
if "_rev_thread" in checkpoint
354+
else None,
355+
"succ_num": nx.to_numpy(checkpoint["_succ_num"])
356+
if "_succ_num" in checkpoint
357+
else None,
358+
"last_succ": nx.to_numpy(checkpoint["_last_succ"])
359+
if "_last_succ" in checkpoint
360+
else None,
361+
"forward": nx.to_numpy(checkpoint["_forward"])
362+
if "_forward" in checkpoint
363+
else None,
364+
"search_arc_num": int(checkpoint.get("search_arc_num", 0)),
365+
"all_arc_num": int(checkpoint.get("all_arc_num", 0)),
366+
}
367+
# Filter out None values
368+
checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None}
369+
370+
G, cost, u, v, result_code, checkpoint_out = emd_c(
371+
a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint)
372+
)
325373

326374
if center_dual:
327375
u, v = center_ot_dual(u, v, a, b)
@@ -345,6 +393,22 @@ def emd(
345393
log["v"] = nx.from_numpy(v, type_as=type_as)
346394
log["warning"] = result_code_string
347395
log["result_code"] = result_code
396+
397+
# Add checkpoint data if requested (preserve original dtypes, don't cast)
398+
if return_checkpoint and checkpoint_out is not None:
399+
log["_flow"] = checkpoint_out["flow"]
400+
log["_pi"] = checkpoint_out["pi"]
401+
log["_state"] = checkpoint_out["state"]
402+
log["_parent"] = checkpoint_out["parent"]
403+
log["_pred"] = checkpoint_out["pred"]
404+
log["_thread"] = checkpoint_out["thread"]
405+
log["_rev_thread"] = checkpoint_out["rev_thread"]
406+
log["_succ_num"] = checkpoint_out["succ_num"]
407+
log["_last_succ"] = checkpoint_out["last_succ"]
408+
log["_forward"] = checkpoint_out["forward"]
409+
log["search_arc_num"] = int(checkpoint_out["search_arc_num"])
410+
log["all_arc_num"] = int(checkpoint_out["all_arc_num"])
411+
348412
return nx.from_numpy(G, type_as=type_as), log
349413
return nx.from_numpy(G, type_as=type_as)
350414

0 commit comments

Comments
 (0)