|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +============================================ |
| 4 | +Sparse Optimal Transport |
| 5 | +============================================ |
| 6 | +
|
| 7 | +In many real-world optimal transport (OT) problems, the transport plan is naturally sparse: only a small fraction of all possible source-target pairs actually exchange mass. In such cases, using sparse OT solvers can provide significant computational speedups and memory savings compared to dense solvers, which compute and store the full transport matrix. |
| 8 | +
|
| 9 | +The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan. |
| 10 | +
|
| 11 | +.. image:: /_static/images/comparison.png |
| 12 | + :align: center |
| 13 | + :width: 80% |
| 14 | + :alt: Dense vs Sparse OT: Speed and Memory Advantages |
| 15 | +""" |
| 16 | + |
| 17 | + |
| 18 | +# Author: Nathan Neike <nathan.neike@example.com> |
| 19 | +# License: MIT License |
| 20 | +# sphinx_gallery_thumbnail_number = 2 |
| 21 | + |
| 22 | +import numpy as np |
| 23 | +import matplotlib.pyplot as plt |
| 24 | +from scipy.sparse import coo_matrix |
| 25 | +import ot |
| 26 | + |
| 27 | + |
| 28 | +############################################################################## |
| 29 | +# Generate minimal example data |
| 30 | +# ------------------------------ |
| 31 | +# |
| 32 | +# We create a simple example with 2 source points and 2 target points to |
| 33 | +# illustrate the concept of sparse optimal transport. |
| 34 | + |
| 35 | +# %% |
| 36 | + |
| 37 | +X = np.array([[0, 0], [1, 0]]) |
| 38 | +Y = np.array([[0, 1], [1, 1]]) |
| 39 | +a = np.array([0.5, 0.5]) |
| 40 | +b = np.array([0.5, 0.5]) |
| 41 | + |
| 42 | + |
| 43 | +############################################################################## |
| 44 | +# Build sparse cost matrix |
| 45 | +# ------------------------- |
| 46 | +# |
| 47 | +# Instead of allowing all possible edges (dense OT), we only allow two edges: |
| 48 | +# source 0 -> target 0 and source 1 -> target 1. This is specified using a |
| 49 | +# sparse matrix format (COO). |
| 50 | + |
| 51 | +# %% |
| 52 | + |
| 53 | +# Only allow two edges: source 0 -> target 0, source 1 -> target 1 |
| 54 | +rows = [0, 1] |
| 55 | +cols = [0, 1] |
| 56 | +vals = [np.linalg.norm(X[0] - Y[0]), np.linalg.norm(X[1] - Y[1])] |
| 57 | +M_sparse = coo_matrix((vals, (rows, cols)), shape=(2, 2)) |
| 58 | + |
| 59 | + |
| 60 | +############################################################################## |
| 61 | +# Solve sparse OT problem |
| 62 | +# ------------------------ |
| 63 | +# |
| 64 | +# When passing a sparse cost matrix to ot.emd with log=True, the solution |
| 65 | +# is returned in the log dictionary with fields 'flow_sources', 'flow_targets', |
| 66 | +# and 'flow_values' containing the edge information. |
| 67 | + |
| 68 | +# %% |
| 69 | + |
| 70 | +G, log = ot.emd(a, b, M_sparse, log=True) |
| 71 | + |
| 72 | +print("Sparse OT cost:", log["cost"]) |
| 73 | +print("Edges:") |
| 74 | +for i, j, v in zip(log["flow_sources"], log["flow_targets"], log["flow_values"]): |
| 75 | + print(f" source {i} -> target {j}, flow={v:.3f}") |
| 76 | + |
| 77 | + |
| 78 | +############################################################################## |
| 79 | +# Visualize allowed edges |
| 80 | +# --------------------------------- |
| 81 | +# |
| 82 | +# The sparse cost matrix only allows transport along specific edges. |
| 83 | + |
| 84 | +# %% |
| 85 | + |
| 86 | + |
| 87 | +plt.figure(figsize=(8, 4)) |
| 88 | + |
| 89 | +# Sparse OT: allowed edges only |
| 90 | +plt.subplot(1, 2, 1) |
| 91 | +plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) |
| 92 | +plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) |
| 93 | +for i, j in zip(rows, cols): |
| 94 | + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.6) |
| 95 | +plt.title("Sparse OT: Allowed Edges Only") |
| 96 | + |
| 97 | +plt.xlim(-0.5, 1.5) |
| 98 | +plt.ylim(-0.5, 1.5) |
| 99 | +plt.xticks([0, 1]) |
| 100 | +plt.yticks([0, 1]) |
| 101 | + |
| 102 | +# Dense OT: all possible edges |
| 103 | +plt.subplot(1, 2, 2) |
| 104 | +plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) |
| 105 | +plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) |
| 106 | +for i in range(2): |
| 107 | + for j in range(2): |
| 108 | + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.3) |
| 109 | +plt.title("Dense OT: All Possible Edges") |
| 110 | +plt.xlim(-0.5, 1.5) |
| 111 | +plt.ylim(-0.5, 1.5) |
| 112 | +plt.xticks([0, 1]) |
| 113 | +plt.yticks([0, 1]) |
| 114 | + |
| 115 | +plt.tight_layout() |
| 116 | + |
| 117 | + |
| 118 | +############################################################################## |
| 119 | +# Larger example with clusters |
| 120 | +# -------------------------------------- |
| 121 | +# |
| 122 | +# Now we create a more realistic example with multiple clusters of sources |
| 123 | +# and targets, where transport is only allowed within each cluster. |
| 124 | + |
| 125 | +# %% |
| 126 | + |
| 127 | +grid_size = 4 |
| 128 | +n_clusters = grid_size * grid_size |
| 129 | +points_per_cluster = 2 |
| 130 | +cluster_spacing = 15.0 |
| 131 | +intra_cluster_spacing = 1.5 |
| 132 | +cluster_centers = ( |
| 133 | + np.array([[i, j] for i in range(grid_size) for j in range(grid_size)]) |
| 134 | + * cluster_spacing |
| 135 | +) |
| 136 | + |
| 137 | +X_large = [] |
| 138 | +Y_large = [] |
| 139 | +a_large = [] |
| 140 | +b_large = [] |
| 141 | + |
| 142 | +for idx, (cx, cy) in enumerate(cluster_centers): |
| 143 | + for i in range(points_per_cluster): |
| 144 | + X_large.append( |
| 145 | + [cx + intra_cluster_spacing * (i - 1), cy - intra_cluster_spacing] |
| 146 | + ) |
| 147 | + a_large.append(1.0 / (n_clusters * points_per_cluster)) |
| 148 | + |
| 149 | + for i in range(points_per_cluster): |
| 150 | + Y_large.append( |
| 151 | + [cx + intra_cluster_spacing * (i - 1), cy + intra_cluster_spacing] |
| 152 | + ) |
| 153 | + b_large.append(1.0 / (n_clusters * points_per_cluster)) |
| 154 | + |
| 155 | +X_large = np.array(X_large) |
| 156 | +Y_large = np.array(Y_large) |
| 157 | +a_large = np.array(a_large) |
| 158 | +b_large = np.array(b_large) |
| 159 | + |
| 160 | +nA = nB = n_clusters * points_per_cluster |
| 161 | +source_labels = np.repeat(np.arange(n_clusters), points_per_cluster) |
| 162 | +sink_labels = np.repeat(np.arange(n_clusters), points_per_cluster) |
| 163 | + |
| 164 | + |
| 165 | +############################################################################## |
| 166 | +# Build sparse cost matrix (intra-cluster only) |
| 167 | +# ---------------------------------------------- |
| 168 | +# |
| 169 | +# We construct a sparse cost matrix that only includes edges within each cluster. |
| 170 | + |
| 171 | +# %% |
| 172 | + |
| 173 | +M_full = ot.dist(X_large, Y_large, metric="euclidean") |
| 174 | + |
| 175 | +rows = [] |
| 176 | +cols = [] |
| 177 | +vals = [] |
| 178 | +for k in range(n_clusters): |
| 179 | + src_idx = np.where(source_labels == k)[0] |
| 180 | + sink_idx = np.where(sink_labels == k)[0] |
| 181 | + for i in src_idx: |
| 182 | + for j in sink_idx: |
| 183 | + rows.append(i) |
| 184 | + cols.append(j) |
| 185 | + vals.append(M_full[i, j]) |
| 186 | +M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(nA, nB)) |
| 187 | + |
| 188 | + |
| 189 | +############################################################################## |
| 190 | +# Visualize allowed edges structure |
| 191 | +# ---------------------------------- |
| 192 | +# |
| 193 | +# Dense OT allows all connections, while sparse OT restricts to intra-cluster edges. |
| 194 | + |
| 195 | +# %% |
| 196 | + |
| 197 | +plt.figure(figsize=(16, 6)) |
| 198 | + |
| 199 | +# Dense OT: all possible edges |
| 200 | +plt.subplot(1, 2, 1) |
| 201 | +for i in range(nA): |
| 202 | + for j in range(nB): |
| 203 | + plt.plot( |
| 204 | + [X_large[i, 0], Y_large[j, 0]], |
| 205 | + [X_large[i, 1], Y_large[j, 1]], |
| 206 | + color="blue", |
| 207 | + alpha=0.1, |
| 208 | + linewidth=0.7, |
| 209 | + ) |
| 210 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
| 211 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 212 | +plt.axis("equal") |
| 213 | +plt.title("Dense OT: All Possible Edges") |
| 214 | + |
| 215 | +# Sparse OT: only intra-cluster edges |
| 216 | +plt.subplot(1, 2, 2) |
| 217 | +for k in range(n_clusters): |
| 218 | + src_idx = np.where(source_labels == k)[0] |
| 219 | + sink_idx = np.where(sink_labels == k)[0] |
| 220 | + for i in src_idx: |
| 221 | + for j in sink_idx: |
| 222 | + plt.plot( |
| 223 | + [X_large[i, 0], Y_large[j, 0]], |
| 224 | + [X_large[i, 1], Y_large[j, 1]], |
| 225 | + color="blue", |
| 226 | + alpha=0.7, |
| 227 | + linewidth=1.5, |
| 228 | + ) |
| 229 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
| 230 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 231 | +plt.axis("equal") |
| 232 | +plt.title("Sparse OT: Only Intra-Cluster Edges") |
| 233 | + |
| 234 | +plt.tight_layout() |
| 235 | + |
| 236 | + |
| 237 | +############################################################################## |
| 238 | +# Solve and compare sparse vs dense OT |
| 239 | +# ------------------------------------- |
| 240 | +# |
| 241 | +# We solve both dense and sparse OT problems and verify that they produce |
| 242 | +# the same optimal solution when the sparse edges include the optimal paths. |
| 243 | + |
| 244 | +# %% |
| 245 | + |
| 246 | +# Solve dense OT (full cost matrix) |
| 247 | +G_dense = ot.emd(a_large, b_large, M_full) |
| 248 | +cost_dense = np.sum(G_dense * M_full) |
| 249 | +print(f"Dense OT cost: {cost_dense:.6f}") |
| 250 | + |
| 251 | +# Solve sparse OT (intra-cluster only) |
| 252 | +G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True) |
| 253 | +cost_sparse = log_sparse["cost"] |
| 254 | +print(f"Sparse OT cost: {cost_sparse:.6f}") |
| 255 | + |
| 256 | + |
| 257 | +############################################################################## |
| 258 | +# Visualize optimal transport plans |
| 259 | +# ---------------------------------- |
| 260 | +# |
| 261 | +# Plot the edges that carry flow in the optimal solutions. |
| 262 | + |
| 263 | +# %% |
| 264 | + |
| 265 | +plt.figure(figsize=(16, 6)) |
| 266 | + |
| 267 | +# Dense OT |
| 268 | +plt.subplot(1, 2, 1) |
| 269 | +for i in range(nA): |
| 270 | + for j in range(nB): |
| 271 | + if G_dense[i, j] > 1e-10: |
| 272 | + plt.plot( |
| 273 | + [X_large[i, 0], Y_large[j, 0]], |
| 274 | + [X_large[i, 1], Y_large[j, 1]], |
| 275 | + color="blue", |
| 276 | + alpha=0.7, |
| 277 | + linewidth=1.5, |
| 278 | + ) |
| 279 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
| 280 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 281 | +plt.axis("equal") |
| 282 | +plt.title("Dense OT: Optimal Transport Plan") |
| 283 | + |
| 284 | +# Sparse OT |
| 285 | +plt.subplot(1, 2, 2) |
| 286 | +if log_sparse["flow_sources"] is not None: |
| 287 | + for i, j, v in zip( |
| 288 | + log_sparse["flow_sources"], |
| 289 | + log_sparse["flow_targets"], |
| 290 | + log_sparse["flow_values"], |
| 291 | + ): |
| 292 | + if v > 1e-10: |
| 293 | + plt.plot( |
| 294 | + [X_large[i, 0], Y_large[j, 0]], |
| 295 | + [X_large[i, 1], Y_large[j, 1]], |
| 296 | + color="blue", |
| 297 | + alpha=0.7, |
| 298 | + linewidth=1.5, |
| 299 | + ) |
| 300 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
| 301 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 302 | +plt.axis("equal") |
| 303 | +plt.title("Sparse OT: Optimal Transport Plan") |
| 304 | + |
| 305 | +plt.tight_layout() |
| 306 | +plt.show() |
0 commit comments