Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver.

#### New features
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)

#### Closed issues
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
Expand Down
19 changes: 9 additions & 10 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import coo_matrix, csr_matrix, issparse
from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse
from scipy.sparse import coo_array, csr_array, issparse

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onl arrays should be used


DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
Expand Down Expand Up @@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
r"""
Creates a sparse tensor in COOrdinate format.

This function follows the api from :any:`scipy.sparse.coo_matrix`
This function follows the api from :any:`scipy.sparse.coo_array`

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html
"""
raise NotImplementedError()

Expand Down Expand Up @@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None):

def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return coo_matrix((data, (rows, cols)), shape=shape)
return coo_array((data, (rows, cols)), shape=shape)
else:
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype)

def issparse(self, a):
return issparse(a)
Expand Down Expand Up @@ -1385,8 +1385,9 @@ def todense(self, a):

def sparse_coo_data(self, a):
# Convert to COO format if needed
if not isinstance(a, coo_matrix):
a_coo = coo_matrix(a)
if not isinstance(a, (coo_array, coo_matrix)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not isinstance(a, (coo_array, coo_matrix)):
if not isinstance(a, coo_array):

# Try to convert to coo_array (prefer modern API)
a_coo = coo_array(a)
else:
a_coo = a

Expand Down Expand Up @@ -1815,9 +1816,7 @@ def sparse_coo_data(self, a):
# JAX doesn't support sparse matrices, so this shouldn't be called
# But if it is, convert the dense array to sparse using scipy
a_np = self.to_numpy(a)
from scipy.sparse import coo_matrix

a_coo = coo_matrix(a_np)
a_coo = coo_array(a_np)
return a_coo.row, a_coo.col, a_coo.data, a_coo.shape

def where(self, condition, x=None, y=None):
Expand Down
9 changes: 9 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,15 @@ def test_emd_sparse_vs_dense(nx):
b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7
)

# Test coo_array element-wise multiplication (only works with coo_array, not coo_matrix)
if nx.__name__ == "numpy":
# This tests that we're using coo_array which supports element-wise operations
M_sparse_np = M_sparse
G_sparse_np = G_sparse
loss_sparse = np.sum(G_sparse_np * M_sparse_np)
# Verify the loss calculation is reasonable
assert loss_sparse >= 0, "Sparse loss should be non-negative"


def test_emd2_sparse_vs_dense(nx):
"""Test that sparse and dense emd2 solvers produce identical costs.
Expand Down
Loading