Skip to content

Commit 5ee1a4d

Browse files
committed
Replaced some more coo_matrix calls
1 parent d889ac9 commit 5ee1a4d

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

ot/backend.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
import scipy
9595
import scipy.linalg
9696
import scipy.special as special
97-
from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse
97+
from scipy.sparse import coo_array, csr_matrix, issparse
9898

9999
DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
100100
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
@@ -1384,9 +1384,8 @@ def todense(self, a):
13841384
return a
13851385

13861386
def sparse_coo_data(self, a):
1387-
# Convert to COO format if needed
1388-
if not isinstance(a, (coo_array, coo_matrix)):
1389-
# Try to convert to coo_array (prefer modern API)
1387+
# Convert to COO array format if needed
1388+
if not isinstance(a, coo_array):
13901389
a_coo = coo_array(a)
13911390
else:
13921391
a_coo = a
@@ -2803,10 +2802,10 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
28032802
rows = self.from_numpy(rows)
28042803
cols = self.from_numpy(cols)
28052804
if type_as is None:
2806-
return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape)
2805+
return cupyx.scipy.sparse.coo_array((data, (rows, cols)), shape=shape)
28072806
else:
28082807
with cp.cuda.Device(type_as.device):
2809-
return cupyx.scipy.sparse.coo_matrix(
2808+
return cupyx.scipy.sparse.coo_array(
28102809
(data, (rows, cols)), shape=shape, dtype=type_as.dtype
28112810
)
28122811

ot/plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import numpy as np
1616
import matplotlib.pylab as pl
1717
from matplotlib import gridspec
18+
from . import backend
19+
from scipy.sparse import issparse, coo_array
1820

1921

2022
def plot1D_mat(
@@ -232,8 +234,6 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
232234
parameters given to the plot functions (default color is black if
233235
nothing given)
234236
"""
235-
from . import backend
236-
from scipy.sparse import issparse, coo_matrix
237237

238238
if ("color" not in kwargs) and ("c" not in kwargs):
239239
kwargs["color"] = "k"
@@ -258,7 +258,7 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
258258
# Not a backend array, check if scipy.sparse
259259
is_sparse = issparse(G)
260260
if is_sparse:
261-
G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
261+
G_coo = G if isinstance(G, coo_array) else G.tocoo()
262262
rows, cols, data = G_coo.row, G_coo.col, G_coo.data
263263

264264
if is_sparse:

0 commit comments

Comments
 (0)