
Commit 4885368

Optimize sparse EMD with sklearn and code cleanup
- Use sklearn.NearestNeighbors in dist_knn() (1.4x faster)
- Remove redundant test code (~50 lines)
- Migrate coo_matrix → coo_array
- Fix parameter ordering consistency
1 parent 54479d5 commit 4885368
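
As context for the first bullet: a dist_knn()-style helper can delegate the neighbor search to sklearn.neighbors.NearestNeighbors roughly as in the sketch below. This is an illustration only, not code from this commit; the actual dist_knn() in the repository may differ, and dist_knn_sketch() is a hypothetical name.

# Hypothetical sketch of a kNN-based sparse cost builder; not the code from
# this commit. Keeps only the k nearest targets for each source point.
import numpy as np
from scipy.sparse import coo_array
from sklearn.neighbors import NearestNeighbors


def dist_knn_sketch(X, Y, k=5):
    """Return a COO cost matrix with distances to the k nearest targets."""
    nn = NearestNeighbors(n_neighbors=k).fit(Y)
    dists, cols = nn.kneighbors(X)                # both of shape (n_source, k)
    rows = np.repeat(np.arange(X.shape[0]), k)    # matches row-major ravel order
    return coo_array(
        (dists.ravel(), (rows, cols.ravel())), shape=(X.shape[0], Y.shape[0])
    )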

File tree: 6 files changed, +152 −374 lines
Binary file not shown (−280 KB).

examples/plot_sparse_emd.py

Lines changed: 3 additions & 70 deletions
@@ -20,78 +20,11 @@

import numpy as np
import matplotlib.pyplot as plt
-from scipy.sparse import coo_matrix
+from scipy.sparse import coo_array
import ot

-
-##############################################################################
-# Minimal example with 4 points
-# ------------------------------
-
-# %%
-
-X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]])
-Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]])
-a = np.array([0.25, 0.25, 0.25, 0.25])
-b = np.array([0.25, 0.25, 0.25, 0.25])
-
-# Build sparse cost matrix allowing only selected edges
-rows = [0, 1, 2, 3]
-cols = [0, 1, 2, 3]
-vals = [np.linalg.norm(X[i] - Y[j]) for i, j in zip(rows, cols)]
-M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4))
-
-
-##############################################################################
-# Solve and display sparse OT solution
-# -------------------------------------
-
-# %%
-
-G, log = ot.emd(a, b, M_sparse, log=True)
-
-print("Sparse OT cost:", log["cost"])
-print("Solution format:", type(G))
-print("Non-zero edges:", G.nnz)
-print("\nEdges:")
-G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
-for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data):
-    if v > 1e-10:
-        print(f" source {i} -> target {j}, flow={v:.3f}")
-
-
-##############################################################################
-# Visualize sparse vs dense edge structure
-# -----------------------------------------
-
-# %%
-
-plt.figure(figsize=(8, 4))
-
-plt.subplot(1, 2, 1)
-plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
-plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
-for i, j in zip(rows, cols):
-    plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.6)
-plt.title("Sparse OT: Allowed Edges Only")
-plt.xlim(-0.5, 2.0)
-plt.ylim(-0.5, 1.5)
-
-plt.subplot(1, 2, 2)
-plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
-plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
-for i in range(len(X)):
-    for j in range(len(Y)):
-        plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.3)
-plt.title("Dense OT: All Possible Edges")
-plt.xlim(-0.5, 2.0)
-plt.ylim(-0.5, 1.5)
-
-plt.tight_layout()
-
-
##############################################################################
-# Larger example: concentric circles
+# Example: concentric circles
# -----------------------------------

# %%
@@ -144,7 +77,7 @@
        cols.append(j)
        vals.append(M_full[i, j])

-M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n))
+M_sparse_large = coo_array((vals, (rows, cols)), shape=(n, n))
allowed_sparse = set(zip(rows, cols))

##############################################################################
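
The minimal 4-point example removed above still shows the intended call pattern; condensed, with coo_array in place of coo_matrix and assuming the sparse ot.emd support on this branch, it reads roughly:

# Condensed sketch of the removed minimal example, using coo_array.
import numpy as np
from scipy.sparse import coo_array
import ot

X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]])
Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]])
a = b = np.full(4, 0.25)

rows = cols = np.arange(4)                        # allow only the diagonal edges
vals = np.linalg.norm(X[rows] - Y[cols], axis=1)  # cost of each allowed edge
M_sparse = coo_array((vals, (rows, cols)), shape=(4, 4))

G, log = ot.emd(a, b, M_sparse, log=True)         # G is returned in sparse format
print("Sparse OT cost:", log["cost"])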

ot/lp/_network_simplex.py

Lines changed: 30 additions & 57 deletions
@@ -289,15 +289,10 @@ def emd(
    ot.optim.cg : General regularized OT
    """

-    edge_sources = None
-    edge_targets = None
-    edge_costs = None
    n1, n2 = None, None

-    # Get backend from M first, then use it for list_to_array
-    # This ensures empty lists [] are converted to arrays in the correct backend
-    nx_M = get_backend(M)
-    a, b = list_to_array(a, b, nx=nx_M)
+    # Convert lists to arrays, using M to detect backend when a,b are empty
+    a, b, M = list_to_array(a, b, M)
    nx = get_backend(a, b, M)

    # Check if M is sparse using backend's issparse method
@@ -325,15 +320,6 @@ def emd(
        if edge_costs.dtype != np.float64:
            edge_costs = edge_costs.astype(np.float64)

-    elif isinstance(M, tuple):
-        raise ValueError(
-            "Tuple format for sparse cost matrix is not supported. "
-            "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)."
-        )
-    else:
-        is_sparse = False
-        a, b, M = list_to_array(a, b, M)
-
    if len(a) != 0:
        type_as = a
    elif len(b) != 0:
@@ -458,10 +444,10 @@ def emd2(
    processes=1,
    numItermax=100000,
    log=False,
+    return_matrix=False,
    center_dual=True,
    numThreads=1,
    check_marginals=True,
-    return_matrix=False,
):
    r"""Solves the Earth Movers distance problem and returns the loss

@@ -514,7 +500,7 @@ def emd2(
        The maximum number of iterations before stopping the optimization
        algorithm if it has not converged.
    log: boolean, optional (default=False)
-        If True, returns a dictionary containing dual
+        If True, returns a dictionary containing the cost and dual
        variables. Otherwise returns only the optimal transportation cost.
    return_matrix: boolean, optional (default=False)
        If True, returns the optimal transportation matrix in the log.
@@ -542,8 +528,9 @@ def emd2(
    W: float, array-like
        Optimal transportation loss for the given parameters
    log: dict
-        If input log is true, a dictionary containing dual
-        variables and exit status
+        If input log is true, a dictionary containing the cost, dual
+        variables (u, v), exit status, and optionally the optimal
+        transportation matrix (G) if return_matrix is True


    Examples
@@ -575,15 +562,9 @@ def emd2(
    ot.optim.cg : General regularized OT
    """

-    edge_sources = None
-    edge_targets = None
-    edge_costs = None
    n1, n2 = None, None

-    # Get backend from M first, then use it for list_to_array
-    # This ensures empty lists [] are converted to arrays in the correct backend
-    nx_M = get_backend(M)
-    a, b = list_to_array(a, b, nx=nx_M)
+    a, b, M = list_to_array(a, b, M)
    nx = get_backend(a, b, M)

    # Check if M is sparse using backend's issparse method
@@ -596,43 +577,26 @@ def emd2(
        # Check if backend supports sparse matrices
        backend_name = nx.__class__.__name__
        if backend_name in ["JaxBackend", "TensorflowBackend"]:
-            raise NotImplementedError(
-                f"Sparse optimal transport is not supported for {backend_name}. "
-                "JAX does not have native sparse matrix support, and TensorFlow's "
-                "sparse implementation is incomplete. Please convert your sparse "
-                "matrix to dense format using M.toarray() or equivalent before calling emd2()."
-            )
+            raise NotImplementedError()

        # Save original M for gradient tracking (before numpy conversion)
        M_original_sparse = M

-        # Extract COO data using backend method - returns numpy arrays
        edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)

-        # Ensure correct dtypes for C++ solver
        if edge_sources.dtype != np.uint64:
            edge_sources = edge_sources.astype(np.uint64)
        if edge_targets.dtype != np.uint64:
            edge_targets = edge_targets.astype(np.uint64)
        if edge_costs.dtype != np.float64:
            edge_costs = edge_costs.astype(np.float64)

-    elif isinstance(M, tuple):
-        raise ValueError(
-            "Tuple format for sparse cost matrix is not supported. "
-            "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)."
-        )
-    else:
-        # Dense matrix
-        is_sparse = False
-        a, b, M = list_to_array(a, b, M)
-
    if len(a) != 0:
        type_as = a
    elif len(b) != 0:
        type_as = b
    else:
-        type_as = a  # Can't use M for sparse case
+        type_as = a

    # Set n1, n2 if not already set (dense case)
    if n1 is None:
@@ -649,7 +613,6 @@ def emd2(

    if is_sparse:
        # Use the original sparse tensor (preserves gradients for PyTorch)
-        # instead of converting from numpy
        edge_costs_original = M_original_sparse
    else:
        edge_costs_original = None
@@ -682,12 +645,11 @@ def emd2(
    numThreads = check_number_threads(numThreads)

    # ============================================================================
-    # DEFINE SOLVER FUNCTION (works for both sparse and dense)
+    # DEFINE SOLVER FUNCTION
    # ============================================================================
    def f(b):
        bsel = b != 0

-        # Call appropriate solver
        if is_sparse:
            # Solve sparse EMD
            flow_sources, flow_targets, flow_values, cost, u, v, result_code = (
@@ -745,6 +707,23 @@ def f(b):
                    grad_M_sparse,
                ),
            )
+
+            # Build transport plan in backend sparse format
+            flow_values_backend = nx.from_numpy(flow_values, type_as=type_as)
+            flow_sources_backend = nx.from_numpy(
+                flow_sources.astype(np.int64), type_as=type_as
+            )
+            flow_targets_backend = nx.from_numpy(
+                flow_targets.astype(np.int64), type_as=type_as
+            )
+
+            G_backend = nx.coo_matrix(
+                flow_values_backend,
+                flow_sources_backend,
+                flow_targets_backend,
+                shape=(n1, n2),
+                type_as=type_as,
+            )
        else:
            # Dense case: warn about integer casting
            if not nx.is_floating_point(type_as):
@@ -772,20 +751,14 @@ def f(b):
        # Return results
        if log or return_matrix:
            log_dict = {
+                "cost": cost,
                "u": nx.from_numpy(u, type_as=type_as),
                "v": nx.from_numpy(v, type_as=type_as),
                "warning": check_result(result_code),
                "result_code": result_code,
            }
-
            if return_matrix:
-                if is_sparse:
-                    G = np.zeros((len(a), len(b)), dtype=np.float64)
-                    G[flow_sources, flow_targets] = flow_values
-                    log_dict["G"] = nx.from_numpy(G, type_as=type_as)
-                else:
-                    log_dict["G"] = G_backend
-
+                log_dict["G"] = G_backend
            return [cost, log_dict]
        else:
            return cost
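
A hedged illustration of the emd2 log contract documented above (keys "cost", "u", "v", and "G" when return_matrix=True); the histograms and dense cost matrix here are made up:

# Illustrative only: inspect the log dictionary returned by ot.emd2.
import numpy as np
import ot

a = b = np.full(3, 1.0 / 3.0)
M = np.array([[0.0, 1.0, 2.0],
              [1.0, 0.0, 1.0],
              [2.0, 1.0, 0.0]])

cost, log = ot.emd2(a, b, M, log=True, return_matrix=True)
print(log["cost"])           # the loss is now also stored under "cost"
print(log["u"], log["v"])    # dual variables
print(log["G"])              # optimal transport plan, since return_matrix=True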

ot/lp/emd_wrap.pyx

Lines changed: 11 additions & 7 deletions
@@ -217,22 +217,26 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
                 np.ndarray[double, ndim=1, mode="c"] edge_costs,
                 uint64_t max_iter):
    """
-    Sparse EMD solver - only considers edges in edge_sources/edge_targets
+    Sparse EMD solver using cost matrix in COO (Coordinate) sparse format.
+
+    The cost matrix is passed as three parallel arrays representing non-zero
+    entries in COO format: (edge_sources[i], edge_targets[i]) -> edge_costs[i].
+    Only edges explicitly provided will be considered by the solver.

    Parameters
    ----------
-    a : (n1,) array
+    a : (n1,) array, float64
        Source histogram
-    b : (n2,) array
+    b : (n2,) array, float64
        Target histogram
    edge_sources : (k,) array, uint64
-        Source indices for each edge
+        Source indices for each edge (row indices in COO format)
    edge_targets : (k,) array, uint64
-        Target indices for each edge
+        Target indices for each edge (column indices in COO format)
    edge_costs : (k,) array, float64
-        Cost for each edge
+        Cost for each edge (non-zero values in COO format)
    max_iter : uint64_t
-        Maximum iterations
+        Maximum number of iterations

    Returns
    -------
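
For reference, the three parallel arrays described in this docstring correspond to the row, col, and data attributes of a SciPy COO matrix. Below is a sketch of extracting them with the dtypes the solver expects; emd_c_sparse is an internal Cython routine, so users would normally go through ot.emd / ot.emd2 instead.

# Sketch: deriving the parallel COO arrays described above from scipy sparse data.
import numpy as np
from scipy.sparse import coo_array

M_sparse = coo_array(
    (np.array([1.0, 0.5]), (np.array([0, 1]), np.array([1, 0]))), shape=(2, 2)
)

edge_sources = M_sparse.row.astype(np.uint64)   # row indices of stored entries
edge_targets = M_sparse.col.astype(np.uint64)   # column indices of stored entries
edge_costs = M_sparse.data.astype(np.float64)   # corresponding edge costs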
