Skip to content

Commit b184cd4

Browse files
committed
Added Example for documentation and modified back setup file to original file
1 parent 1e28771 commit b184cd4

File tree

3 files changed

+307
-13
lines changed

3 files changed

+307
-13
lines changed
280 KB
Loading

examples/plot_sparse_emd.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
============================================
4+
Sparse Optimal Transport
5+
============================================
6+
7+
In many real-world optimal transport (OT) problems, the transport plan is naturally sparse: only a small fraction of all possible source-target pairs actually exchange mass. In such cases, using sparse OT solvers can provide significant computational speedups and memory savings compared to dense solvers, which compute and store the full transport matrix.
8+
9+
The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan.
10+
11+
.. image:: /_static/images/comparison.png
12+
:align: center
13+
:width: 80%
14+
:alt: Dense vs Sparse OT: Speed and Memory Advantages
15+
"""
16+
17+
18+
# Author: Nathan Neike <nathan.neike@example.com>
19+
# License: MIT License
20+
# sphinx_gallery_thumbnail_number = 2
21+
22+
import numpy as np
23+
import matplotlib.pyplot as plt
24+
from scipy.sparse import coo_matrix
25+
import ot
26+
27+
28+
##############################################################################
29+
# Generate minimal example data
30+
# ------------------------------
31+
#
32+
# We create a simple example with 2 source points and 2 target points to
33+
# illustrate the concept of sparse optimal transport.
34+
35+
# %%
36+
37+
X = np.array([[0, 0], [1, 0]])
38+
Y = np.array([[0, 1], [1, 1]])
39+
a = np.array([0.5, 0.5])
40+
b = np.array([0.5, 0.5])
41+
42+
43+
##############################################################################
44+
# Build sparse cost matrix
45+
# -------------------------
46+
#
47+
# Instead of allowing all possible edges (dense OT), we only allow two edges:
48+
# source 0 -> target 0 and source 1 -> target 1. This is specified using a
49+
# sparse matrix format (COO).
50+
51+
# %%
52+
53+
# Only allow two edges: source 0 -> target 0, source 1 -> target 1
54+
rows = [0, 1]
55+
cols = [0, 1]
56+
vals = [np.linalg.norm(X[0] - Y[0]), np.linalg.norm(X[1] - Y[1])]
57+
M_sparse = coo_matrix((vals, (rows, cols)), shape=(2, 2))
58+
59+
60+
##############################################################################
61+
# Solve sparse OT problem
62+
# ------------------------
63+
#
64+
# When passing a sparse cost matrix to ot.emd with log=True, the solution
65+
# is returned in the log dictionary with fields 'flow_sources', 'flow_targets',
66+
# and 'flow_values' containing the edge information.
67+
68+
# %%
69+
70+
G, log = ot.emd(a, b, M_sparse, log=True)
71+
72+
print("Sparse OT cost:", log["cost"])
73+
print("Edges:")
74+
for i, j, v in zip(log["flow_sources"], log["flow_targets"], log["flow_values"]):
75+
print(f" source {i} -> target {j}, flow={v:.3f}")
76+
77+
78+
##############################################################################
79+
# Visualize allowed edges
80+
# ---------------------------------
81+
#
82+
# The sparse cost matrix only allows transport along specific edges.
83+
84+
# %%
85+
86+
87+
plt.figure(figsize=(8, 4))
88+
89+
# Sparse OT: allowed edges only
90+
plt.subplot(1, 2, 1)
91+
plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
92+
plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
93+
for i, j in zip(rows, cols):
94+
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.6)
95+
plt.title("Sparse OT: Allowed Edges Only")
96+
97+
plt.xlim(-0.5, 1.5)
98+
plt.ylim(-0.5, 1.5)
99+
plt.xticks([0, 1])
100+
plt.yticks([0, 1])
101+
102+
# Dense OT: all possible edges
103+
plt.subplot(1, 2, 2)
104+
plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
105+
plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
106+
for i in range(2):
107+
for j in range(2):
108+
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.3)
109+
plt.title("Dense OT: All Possible Edges")
110+
plt.xlim(-0.5, 1.5)
111+
plt.ylim(-0.5, 1.5)
112+
plt.xticks([0, 1])
113+
plt.yticks([0, 1])
114+
115+
plt.tight_layout()
116+
117+
118+
##############################################################################
119+
# Larger example with clusters
120+
# --------------------------------------
121+
#
122+
# Now we create a more realistic example with multiple clusters of sources
123+
# and targets, where transport is only allowed within each cluster.
124+
125+
# %%
126+
127+
grid_size = 4
128+
n_clusters = grid_size * grid_size
129+
points_per_cluster = 2
130+
cluster_spacing = 15.0
131+
intra_cluster_spacing = 1.5
132+
cluster_centers = (
133+
np.array([[i, j] for i in range(grid_size) for j in range(grid_size)])
134+
* cluster_spacing
135+
)
136+
137+
X_large = []
138+
Y_large = []
139+
a_large = []
140+
b_large = []
141+
142+
for idx, (cx, cy) in enumerate(cluster_centers):
143+
for i in range(points_per_cluster):
144+
X_large.append(
145+
[cx + intra_cluster_spacing * (i - 1), cy - intra_cluster_spacing]
146+
)
147+
a_large.append(1.0 / (n_clusters * points_per_cluster))
148+
149+
for i in range(points_per_cluster):
150+
Y_large.append(
151+
[cx + intra_cluster_spacing * (i - 1), cy + intra_cluster_spacing]
152+
)
153+
b_large.append(1.0 / (n_clusters * points_per_cluster))
154+
155+
X_large = np.array(X_large)
156+
Y_large = np.array(Y_large)
157+
a_large = np.array(a_large)
158+
b_large = np.array(b_large)
159+
160+
nA = nB = n_clusters * points_per_cluster
161+
source_labels = np.repeat(np.arange(n_clusters), points_per_cluster)
162+
sink_labels = np.repeat(np.arange(n_clusters), points_per_cluster)
163+
164+
165+
##############################################################################
166+
# Build sparse cost matrix (intra-cluster only)
167+
# ----------------------------------------------
168+
#
169+
# We construct a sparse cost matrix that only includes edges within each cluster.
170+
171+
# %%
172+
173+
M_full = ot.dist(X_large, Y_large, metric="euclidean")
174+
175+
rows = []
176+
cols = []
177+
vals = []
178+
for k in range(n_clusters):
179+
src_idx = np.where(source_labels == k)[0]
180+
sink_idx = np.where(sink_labels == k)[0]
181+
for i in src_idx:
182+
for j in sink_idx:
183+
rows.append(i)
184+
cols.append(j)
185+
vals.append(M_full[i, j])
186+
M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(nA, nB))
187+
188+
189+
##############################################################################
190+
# Visualize allowed edges structure
191+
# ----------------------------------
192+
#
193+
# Dense OT allows all connections, while sparse OT restricts transport to intra-cluster edges.
194+
195+
# %%
196+
197+
plt.figure(figsize=(16, 6))
198+
199+
# Dense OT: all possible edges
200+
plt.subplot(1, 2, 1)
201+
for i in range(nA):
202+
for j in range(nB):
203+
plt.plot(
204+
[X_large[i, 0], Y_large[j, 0]],
205+
[X_large[i, 1], Y_large[j, 1]],
206+
color="blue",
207+
alpha=0.1,
208+
linewidth=0.7,
209+
)
210+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
211+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
212+
plt.axis("equal")
213+
plt.title("Dense OT: All Possible Edges")
214+
215+
# Sparse OT: only intra-cluster edges
216+
plt.subplot(1, 2, 2)
217+
for k in range(n_clusters):
218+
src_idx = np.where(source_labels == k)[0]
219+
sink_idx = np.where(sink_labels == k)[0]
220+
for i in src_idx:
221+
for j in sink_idx:
222+
plt.plot(
223+
[X_large[i, 0], Y_large[j, 0]],
224+
[X_large[i, 1], Y_large[j, 1]],
225+
color="blue",
226+
alpha=0.7,
227+
linewidth=1.5,
228+
)
229+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
230+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
231+
plt.axis("equal")
232+
plt.title("Sparse OT: Only Intra-Cluster Edges")
233+
234+
plt.tight_layout()
235+
236+
237+
##############################################################################
238+
# Solve and compare sparse vs dense OT
239+
# -------------------------------------
240+
#
241+
# We solve both dense and sparse OT problems and print their costs; the two
241+
# costs coincide when the sparse edge set contains an optimal dense plan.
243+
244+
# %%
245+
246+
# Solve dense OT (full cost matrix)
247+
G_dense = ot.emd(a_large, b_large, M_full)
248+
cost_dense = np.sum(G_dense * M_full)
249+
print(f"Dense OT cost: {cost_dense:.6f}")
250+
251+
# Solve sparse OT (intra-cluster only)
252+
G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True)
253+
cost_sparse = log_sparse["cost"]
254+
print(f"Sparse OT cost: {cost_sparse:.6f}")
255+
256+
257+
##############################################################################
258+
# Visualize optimal transport plans
259+
# ----------------------------------
260+
#
261+
# Plot the edges that carry flow in the optimal solutions.
262+
263+
# %%
264+
265+
plt.figure(figsize=(16, 6))
266+
267+
# Dense OT
268+
plt.subplot(1, 2, 1)
269+
for i in range(nA):
270+
for j in range(nB):
271+
if G_dense[i, j] > 1e-10:
272+
plt.plot(
273+
[X_large[i, 0], Y_large[j, 0]],
274+
[X_large[i, 1], Y_large[j, 1]],
275+
color="blue",
276+
alpha=0.7,
277+
linewidth=1.5,
278+
)
279+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
280+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
281+
plt.axis("equal")
282+
plt.title("Dense OT: Optimal Transport Plan")
283+
284+
# Sparse OT
285+
plt.subplot(1, 2, 2)
286+
if log_sparse["flow_sources"] is not None:
287+
for i, j, v in zip(
288+
log_sparse["flow_sources"],
289+
log_sparse["flow_targets"],
290+
log_sparse["flow_values"],
291+
):
292+
if v > 1e-10:
293+
plt.plot(
294+
[X_large[i, 0], Y_large[j, 0]],
295+
[X_large[i, 1], Y_large[j, 1]],
296+
color="blue",
297+
alpha=0.7,
298+
linewidth=1.5,
299+
)
300+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
301+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
302+
plt.axis("equal")
303+
plt.title("Sparse OT: Optimal Transport Plan")
304+
305+
plt.tight_layout()
306+
plt.show()

setup.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,7 @@
5050
link_args += flags
5151

5252
if sys.platform.startswith("darwin"):
53-
# Only add -stdlib=libc++ for Clang, not GCC
54-
# GCC uses libstdc++ by default and doesn't recognize -stdlib flag
55-
import subprocess
56-
try:
57-
# Check if using clang
58-
compiler = os.environ.get('CXX', 'c++')
59-
version_output = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT).decode()
60-
if 'clang' in version_output.lower():
61-
compile_args.append("-stdlib=libc++")
62-
except Exception:
63-
# If we can't determine, don't add the flag (safer for GCC)
64-
pass
65-
53+
compile_args.append("-stdlib=libc++")
6654
sdk_path = subprocess.check_output(["xcrun", "--show-sdk-path"])
6755
os.environ["CFLAGS"] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
6856

0 commit comments

Comments
 (0)