Skip to content

Commit 65f7063

Browse files
authored
Merge branch 'master' into dev_sliced_plans
2 parents a19753e + 9412193 commit 65f7063

File tree

17 files changed

+1415
-121
lines changed

17 files changed

+1415
-121
lines changed

.github/workflows/build_tests.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727

2828

2929
- name: Checking Out Repository
30-
uses: actions/checkout@v2
30+
uses: actions/checkout@v4
3131
# Install Python & Packages
3232
- uses: actions/setup-python@v4
3333
with:
@@ -39,6 +39,20 @@ jobs:
3939
pre-commit install --install-hooks
4040
pre-commit run --all-files
4141
42+
build_from_source:
43+
runs-on: ubuntu-latest
44+
steps:
45+
- uses: actions/checkout@v4
46+
- name: Set up Python
47+
uses: actions/setup-python@v5
48+
with:
49+
python-version: "3.12"
50+
- name: Build from source
51+
run: |
52+
python -m pip install --upgrade pip setuptools wheel
53+
python -m pip install cython numpy
54+
python setup.py sdist bdist_wheel
55+
pip install dist/*.tar.gz
4256
4357
linux:
4458

CITATION.cff

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ authors:
7272
family-names: Fernandes Montesuma
7373
affiliation: Université Paris-Saclay & CEA-List
7474
orcid: 'https://orcid.org/0000-0003-3850-4602'
75+
- given-names: Nathan
76+
family-names: Neike
77+
affiliation: Hi! PARIS
7578
identifiers:
7679
- type: url
7780
value: 'https://github.com/PythonOT/POT'

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ The contributors to this library are:
5858
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)
5959
* [Julie Delon](https://judelo.github.io/) (GMM OT)
6060
* [Samuel Boïté](https://samuelbx.github.io/) (GMM OT)
61+
* [Nathan Neike](https://github.com/nathanneike) (Sparse EMD solver)
6162

6263

6364
## Acknowledgments

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h
1010
include ot/lp/full_bipartitegraph_omp.h
1111
include ot/lp/network_simplex_simple.h
1212
include ot/lp/network_simplex_simple_omp.h
13+
include ot/lp/sparse_bipartitegraph.h
1314
include ot/partial/partial_cython.pyx

examples/plot_sparse_emd.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
============================================
4+
Sparse Optimal Transport
5+
============================================
6+
7+
In many real-world optimal transport (OT) problems, the transport plan is
8+
naturally sparse: only a small fraction of all possible source-target pairs
9+
actually exchange mass. Using sparse OT solvers can provide significant
10+
computational speedups and memory savings compared to dense solvers.
11+
12+
This example demonstrates how to use sparse cost matrices with POT's EMD solver,
13+
comparing the sparse and dense formulations on a synthetic
14+
concentric circles dataset.
15+
"""
16+
17+
# Author: Nathan Neike
18+
#
19+
# License: MIT License
20+
# sphinx_gallery_thumbnail_number = 2
21+
22+
import numpy as np
23+
import matplotlib.pyplot as plt
24+
from scipy.sparse import coo_array
25+
import ot
26+
27+
##############################################################################
28+
# Example: concentric circles
29+
# -----------------------------------
30+
31+
# %%
32+
33+
# Problem size: n points, split into n_clusters consecutive groups of equal size.
n_clusters = 8
points_per_cluster = 25
n = n_clusters * points_per_cluster
k_neighbors = 8  # number of candidate targets kept for every source point
rng = np.random.default_rng(0)  # fixed seed

r_source = 1.0
r_target = 2.0
noise_scale = 0.06

theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster)

# Sources sit on the circle of radius r_source, targets on the circle of radius
# r_target; both are perturbed with Gaussian noise (sources are drawn first).
X_large = np.column_stack(
    [r_source * np.cos(theta), r_source * np.sin(theta)]
) + rng.normal(scale=noise_scale, size=(n, 2))
Y_large = np.column_stack(
    [r_target * np.cos(theta), r_target * np.sin(theta)]
) + rng.normal(scale=noise_scale, size=(n, 2))

# Every point carries the same mass.
a_large = np.full(n, 1.0 / n_clusters / points_per_cluster)
b_large = np.full(n, 1.0 / n_clusters / points_per_cluster)

M_full = ot.dist(X_large, Y_large, metric="euclidean")

# Build sparse cost matrix: intra-cluster k-nearest neighbors
angles_X = np.arctan2(X_large[:, 1], X_large[:, 0])
angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0])

rows, cols, vals = [], [], []
for k in range(n_clusters):
    # Sources and targets of cluster k share the same index set.
    members = np.flatnonzero(cluster_labels == k)
    for i in members:
        # Angular gap to each target of the cluster, wrapped by np.angle.
        gap = np.angle(np.exp(1j * (angles_Y[members] - angles_X[i])))
        nearest = members[np.argsort(np.abs(gap))[:k_neighbors]]
        rows.extend([i] * len(nearest))
        cols.extend(nearest)
        vals.extend(M_full[i, nearest])

M_sparse_large = coo_array((vals, (rows, cols)), shape=(n, n))
allowed_sparse = set(zip(rows, cols))
83+
84+
##############################################################################
85+
# Visualize edge structures
86+
# --------------------------
87+
88+
# %%
89+
90+
def _draw_edges(src, tgt, alpha, title):
    """Draw one line per (source, target) index pair, then both point clouds.

    ``src`` and ``tgt`` are equal-length index sequences into ``X_large`` and
    ``Y_large``. Each column of the stacked arrays is one line, so a single
    ``plt.plot`` call draws every edge in the given order.
    """
    xs = np.vstack([X_large[src, 0], Y_large[tgt, 0]])
    ys = np.vstack([X_large[src, 1], Y_large[tgt, 1]])
    plt.plot(xs, ys, color="blue", alpha=alpha, linewidth=0.05)
    plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
    plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
    plt.axis("equal")
    plt.title(title)


plt.figure(figsize=(16, 6))

# Left panel: every source is connected to every target (i-major order).
plt.subplot(1, 2, 1)
all_points = np.arange(n)
_draw_edges(
    np.repeat(all_points, n),
    np.tile(all_points, n),
    0.2,
    "Dense OT: All Possible Edges",
)

# Right panel: only the edges stored in the sparse cost matrix.
plt.subplot(1, 2, 2)
_draw_edges(
    [i for i, _ in allowed_sparse],
    [j for _, j in allowed_sparse],
    1,
    "Sparse OT: Intra-Cluster k-NN Edges",
)

plt.tight_layout()
plt.show()
123+
124+
##############################################################################
125+
# Solve and visualize transport plans
126+
# ------------------------------------
127+
128+
# %%
129+
130+
# ``ot.plot`` depends on matplotlib and is not loaded by ``import ot``;
# import it explicitly so that this example also runs on its own.
import ot.plot  # noqa: E402

# Dense problem: the cost is recomputed from the returned plan.
G_dense = ot.emd(a_large, b_large, M_full)
cost_dense = np.sum(G_dense * M_full)
print(f"Dense OT cost: {cost_dense:.6f}")

# Sparse problem: the cost is read from the solver log.
G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True)
cost_sparse = log_sparse["cost"]
print(f"Sparse OT cost: {cost_sparse:.6f}")

plt.figure(figsize=(16, 6))

panels = [
    (G_dense, "Dense OT: Optimal Transport Plan"),
    (G_sparse, "Sparse OT: Optimal Transport Plan"),
]
for position, (plan, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    ot.plot.plot2D_samples_mat(
        X_large, Y_large, plan, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5
    )
    # zorder=3 keeps the points drawn above the transport lines.
    plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3)
    plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3)
    plt.axis("equal")
    plt.title(title)

plt.tight_layout()
plt.show()

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
# utils functions
8080
from .utils import dist, unif, tic, toc, toq
8181

82-
__version__ = "0.9.6.post1"
82+
__version__ = "0.9.7.dev0"
8383

8484
__all__ = [
8585
"emd",

ot/backend.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
import jax
120120
import jax.numpy as jnp
121121
import jax.scipy.special as jspecial
122-
from jax.lib import xla_bridge
122+
from jax.extend.backend import get_backend as _jax_get_backend
123123

124124
jax_type = jax.numpy.ndarray
125125
jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24
@@ -178,7 +178,16 @@ def _get_backend_instance(backend_impl):
178178

179179

180180
def _check_args_backend(backend_impl, args):
181-
is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
181+
# Get backend instance to use issparse method
182+
backend = _get_backend_instance(backend_impl)
183+
184+
# Check if each arg is either:
185+
# 1. An instance of backend.__type__ (e.g., np.ndarray for NumPy)
186+
# 2. A sparse matrix recognized by backend.issparse() (e.g., scipy.sparse for NumPy)
187+
is_instance = set(
188+
isinstance(arg, backend_impl.__type__) or backend.issparse(arg) for arg in args
189+
)
190+
182191
# check that either all arguments match the backend type or none of them do
183192
if len(is_instance) == 1:
184193
return is_instance.pop()
@@ -839,6 +848,31 @@ def todense(self, a):
839848
"""
840849
raise NotImplementedError()
841850

851+
def sparse_coo_data(self, a):
    r"""
    Returns the COO components (row, col, data, shape) of a sparse matrix.

    The indices and values are returned as numpy arrays and the shape as a
    tuple. This is used to interface with C++ solvers that require explicit
    edge lists.

    Parameters
    ----------
    a : sparse matrix
        Sparse matrix in the backend's COO format

    Returns
    -------
    row : numpy.ndarray
        Row indices (1D array)
    col : numpy.ndarray
        Column indices (1D array)
    data : numpy.ndarray
        Data values (1D array)
    shape : tuple
        Shape of the matrix (n_rows, n_cols)

    Raises
    ------
    NotImplementedError
        Always; this base implementation is a placeholder.
    """
    raise NotImplementedError()
875+
842876
def where(self, condition, x, y):
843877
r"""
844878
Returns elements chosen from x or y depending on condition.
@@ -1349,6 +1383,15 @@ def todense(self, a):
13491383
else:
13501384
return a
13511385

1386+
def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` for ``a`` in COO layout.

    ``a`` is used directly when it is already a ``coo_matrix``; any other
    input is first passed through the ``coo_matrix`` constructor.
    """
    coo = a if isinstance(a, coo_matrix) else coo_matrix(a)
    return coo.row, coo.col, coo.data, coo.shape
1394+
13521395
def where(self, condition, x=None, y=None):
13531396
if x is None and y is None:
13541397
return np.where(condition)
@@ -1509,7 +1552,7 @@ def __init__(self):
15091552
self.__type_list__ = []
15101553
# available_devices = jax.devices("cpu")
15111554
available_devices = []
1512-
if xla_bridge.get_backend().platform == "gpu":
1555+
if _jax_get_backend().platform == "gpu":
15131556
available_devices += jax.devices("gpu")
15141557
for d in available_devices:
15151558
self.__type_list__ += [
@@ -1768,6 +1811,15 @@ def todense(self, a):
17681811
# Currently, JAX does not support sparse matrices
17691812
return a
17701813

1814+
def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` built from the non-zeros of ``a``.

    JAX doesn't support sparse matrices, so this is only a fallback: the
    array is converted to numpy and wrapped in a scipy ``coo_matrix``.
    """
    from scipy.sparse import coo_matrix

    coo = coo_matrix(self.to_numpy(a))
    return coo.row, coo.col, coo.data, coo.shape
1822+
17711823
def where(self, condition, x=None, y=None):
17721824
if x is None and y is None:
17731825
return jnp.where(condition)
@@ -1938,6 +1990,7 @@ def __init__(self):
19381990
self.rng_cuda_ = torch.Generator("cpu")
19391991

19401992
from torch.autograd import Function
1993+
from torch.autograd.function import once_differentiable
19411994

19421995
# define a function that takes inputs val and grads
19431996
# and returns a val tensor with proper gradients
@@ -1952,7 +2005,31 @@ def backward(ctx, grad_output):
19522005
# the gradients are grad
19532006
return (None, None) + tuple(g * grad_output for g in ctx.grads)
19542007

2008+
# define a differentiable SPD matrix sqrt
2009+
# with closed-form VJP
2010+
class MatrixSqrtFunction(Function):
    """Matrix square root through an eigendecomposition, with a custom backward.

    ``forward`` symmetrizes its input, clamps negative eigenvalues to zero and
    rebuilds the matrix from the square-rooted eigenvalues. ``backward`` reuses
    the saved eigendecomposition and is marked ``once_differentiable``.
    """

    @staticmethod
    def forward(ctx, a):
        # Symmetrize over the last two axes before calling eigh.
        sym = 0.5 * (a + a.transpose(-2, -1))
        eigvals, eigvecs = torch.linalg.eigh(sym)
        sqrt_vals = eigvals.clamp_min(0).sqrt()
        ctx.save_for_backward(sqrt_vals, eigvecs)
        # Scale column k of eigvecs by sqrt_vals[k], then rotate back.
        return (eigvecs * sqrt_vals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)

    @staticmethod
    @once_differentiable
    def backward(ctx, g):
        sqrt_vals, eigvecs = ctx.saved_tensors
        sym_grad = 0.5 * (g + g.transpose(-2, -1))
        # Incoming gradient expressed in the eigenvector basis.
        rotated = eigvecs.transpose(-2, -1) @ sym_grad @ eigvecs
        # denom[..., i, j] = sqrt_vals[i] + sqrt_vals[j]
        denom = sqrt_vals.unsqueeze(-1) + sqrt_vals.unsqueeze(-2)
        # Entries whose denominator is zero are set to zero after the division.
        solved = (rotated / denom).masked_fill(denom == 0, 0)
        return eigvecs @ solved @ eigvecs.transpose(-2, -1)
2030+
19552031
self.ValFunction = ValFunction
2032+
self.MatrixSqrtFunction = MatrixSqrtFunction
19562033

19572034
def _to_numpy(self, a):
19582035
if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2315,6 +2392,20 @@ def todense(self, a):
23152392
else:
23162393
return a
23172394

2395+
def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` of a torch sparse tensor as numpy data.

    The tensor is coalesced first to ensure unique indices. The first two rows
    of the index tensor are returned as ``row`` and ``col``.
    """
    coalesced = a.coalesce()
    indices = coalesced._indices()
    values = coalesced._values()

    return (
        self.to_numpy(indices[0]),
        self.to_numpy(indices[1]),
        self.to_numpy(values),
        tuple(coalesced.shape),
    )
2408+
23182409
def where(self, condition, x=None, y=None):
23192410
if x is None and y is None:
23202411
return torch.where(condition)
@@ -2395,12 +2486,7 @@ def pinv(self, a, hermitian=False):
23952486
return torch.linalg.pinv(a, hermitian=hermitian)
23962487

23972488
def sqrtm(self, a):
2398-
L, V = torch.linalg.eigh(a)
2399-
L = torch.sqrt(L)
2400-
# Q[...] = V[...] @ diag(L[...])
2401-
Q = torch.einsum("...jk,...k->...jk", V, L)
2402-
# R[...] = Q[...] @ V[...].T
2403-
return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
2489+
return self.MatrixSqrtFunction.apply(a)
24042490

24052491
def eigh(self, a):
24062492
return torch.linalg.eigh(a)

0 commit comments

Comments
 (0)