Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver.

#### New features
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)

#### Closed issues
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
Expand Down
24 changes: 11 additions & 13 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import coo_matrix, csr_matrix, issparse
from scipy.sparse import coo_array, csr_matrix, issparse

DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
Expand Down Expand Up @@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
r"""
Creates a sparse tensor in COOrdinate format.

This function follows the api from :any:`scipy.sparse.coo_matrix`
This function follows the api from :any:`scipy.sparse.coo_array`

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html
"""
raise NotImplementedError()

Expand Down Expand Up @@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None):

def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return coo_matrix((data, (rows, cols)), shape=shape)
return coo_array((data, (rows, cols)), shape=shape)
else:
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype)

def issparse(self, a):
return issparse(a)
Expand Down Expand Up @@ -1384,9 +1384,9 @@ def todense(self, a):
return a

def sparse_coo_data(self, a):
# Convert to COO format if needed
if not isinstance(a, coo_matrix):
a_coo = coo_matrix(a)
# Convert to COO array format if needed
if not isinstance(a, coo_array):
a_coo = coo_array(a)
else:
a_coo = a

Expand Down Expand Up @@ -1815,9 +1815,7 @@ def sparse_coo_data(self, a):
# JAX doesn't support sparse matrices, so this shouldn't be called
# But if it is, convert the dense array to sparse using scipy
a_np = self.to_numpy(a)
from scipy.sparse import coo_matrix

a_coo = coo_matrix(a_np)
a_coo = coo_array(a_np)
return a_coo.row, a_coo.col, a_coo.data, a_coo.shape

def where(self, condition, x=None, y=None):
Expand Down Expand Up @@ -2804,10 +2802,10 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
rows = self.from_numpy(rows)
cols = self.from_numpy(cols)
if type_as is None:
return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape)
return cupyx.scipy.sparse.coo_array((data, (rows, cols)), shape=shape)
else:
with cp.cuda.Device(type_as.device):
return cupyx.scipy.sparse.coo_matrix(
return cupyx.scipy.sparse.coo_array(
(data, (rows, cols)), shape=shape, dtype=type_as.dtype
)

Expand Down
6 changes: 3 additions & 3 deletions ot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import matplotlib.pylab as pl
from matplotlib import gridspec
from . import backend
from scipy.sparse import issparse, coo_array


def plot1D_mat(
Expand Down Expand Up @@ -232,8 +234,6 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
parameters given to the plot functions (default color is black if
nothing given)
"""
from . import backend
from scipy.sparse import issparse, coo_matrix

if ("color" not in kwargs) and ("c" not in kwargs):
kwargs["color"] = "k"
Expand All @@ -258,7 +258,7 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
# Not a backend array, check if scipy.sparse
is_sparse = issparse(G)
if is_sparse:
G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
G_coo = G if isinstance(G, coo_array) else G.tocoo()
rows, cols, data = G_coo.row, G_coo.col, G_coo.data

if is_sparse:
Expand Down
9 changes: 9 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,15 @@ def test_emd_sparse_vs_dense(nx):
b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7
)

# Test coo_array element-wise multiplication (only works with coo_array, not coo_matrix)
if nx.__name__ == "numpy":
# This tests that we're using coo_array which supports element-wise operations
M_sparse_np = M_sparse
G_sparse_np = G_sparse
loss_sparse = np.sum(G_sparse_np * M_sparse_np)
# Verify the loss calculation is reasonable
assert loss_sparse >= 0, "Sparse loss should be non-negative"


def test_emd2_sparse_vs_dense(nx):
"""Test that sparse and dense emd2 solvers produce identical costs.
Expand Down
Loading