Skip to content

Commit 975ca6c

Browse files
committed
Replaced coo_matrix with coo_array better compatability and added test to test coo_array functionnality
1 parent 9412193 commit 975ca6c

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

ot/backend.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
import scipy
9595
import scipy.linalg
9696
import scipy.special as special
97-
from scipy.sparse import coo_matrix, csr_matrix, issparse
97+
from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse
9898

9999
DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
100100
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
@@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
802802
r"""
803803
Creates a sparse tensor in COOrdinate format.
804804
805-
This function follows the api from :any:`scipy.sparse.coo_matrix`
805+
This function follows the api from :any:`scipy.sparse.coo_array`
806806
807-
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
807+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html
808808
"""
809809
raise NotImplementedError()
810810

@@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None):
13541354

13551355
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
13561356
if type_as is None:
1357-
return coo_matrix((data, (rows, cols)), shape=shape)
1357+
return coo_array((data, (rows, cols)), shape=shape)
13581358
else:
1359-
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
1359+
return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
13601360

13611361
def issparse(self, a):
13621362
return issparse(a)
@@ -1385,8 +1385,9 @@ def todense(self, a):
13851385

13861386
def sparse_coo_data(self, a):
13871387
# Convert to COO format if needed
1388-
if not isinstance(a, coo_matrix):
1389-
a_coo = coo_matrix(a)
1388+
if not isinstance(a, (coo_array, coo_matrix)):
1389+
# Try to convert to coo_array (prefer modern API)
1390+
a_coo = coo_array(a)
13901391
else:
13911392
a_coo = a
13921393

@@ -1815,9 +1816,7 @@ def sparse_coo_data(self, a):
18151816
# JAX doesn't support sparse matrices, so this shouldn't be called
18161817
# But if it is, convert the dense array to sparse using scipy
18171818
a_np = self.to_numpy(a)
1818-
from scipy.sparse import coo_matrix
1819-
1820-
a_coo = coo_matrix(a_np)
1819+
a_coo = coo_array(a_np)
18211820
return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
18221821

18231822
def where(self, condition, x=None, y=None):

test/test_ot.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,15 @@ def test_emd_sparse_vs_dense(nx):
992992
b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7
993993
)
994994

995+
# Test coo_array element-wise multiplication (only works with coo_array, not coo_matrix)
996+
if nx.__name__ == "numpy":
997+
# This tests that we're using coo_array which supports element-wise operations
998+
M_sparse_np = M_sparse
999+
G_sparse_np = G_sparse
1000+
loss_sparse = np.sum(G_sparse_np * M_sparse_np)
1001+
# Verify the loss calculation is reasonable
1002+
assert loss_sparse >= 0, "Sparse loss should be non-negative"
1003+
9951004

9961005
def test_emd2_sparse_vs_dense(nx):
9971006
"""Test that sparse and dense emd2 solvers produce identical costs.

0 commit comments

Comments
 (0)