|
94 | 94 | import scipy |
95 | 95 | import scipy.linalg |
96 | 96 | import scipy.special as special |
97 | | -from scipy.sparse import coo_matrix, csr_matrix, issparse |
| 97 | +from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse |
98 | 98 |
|
99 | 99 | DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH" |
100 | 100 | DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX" |
@@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): |
802 | 802 | r""" |
803 | 803 | Creates a sparse tensor in COOrdinate format. |
804 | 804 |
|
805 | | - This function follows the api from :any:`scipy.sparse.coo_matrix` |
| 805 | + This function follows the api from :any:`scipy.sparse.coo_array` |
806 | 806 |
|
807 | | - See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html |
| 807 | + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html |
808 | 808 | """ |
809 | 809 | raise NotImplementedError() |
810 | 810 |
|
@@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None): |
1354 | 1354 |
|
1355 | 1355 | def coo_matrix(self, data, rows, cols, shape=None, type_as=None): |
1356 | 1356 | if type_as is None: |
1357 | | - return coo_matrix((data, (rows, cols)), shape=shape) |
| 1357 | + return coo_array((data, (rows, cols)), shape=shape) |
1358 | 1358 | else: |
1359 | | - return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) |
| 1359 | + return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype) |
1360 | 1360 |
|
1361 | 1361 | def issparse(self, a): |
1362 | 1362 | return issparse(a) |
@@ -1385,8 +1385,9 @@ def todense(self, a): |
1385 | 1385 |
|
1386 | 1386 | def sparse_coo_data(self, a): |
1387 | 1387 | # Convert to COO format if needed |
1388 | | - if not isinstance(a, coo_matrix): |
1389 | | - a_coo = coo_matrix(a) |
| 1388 | + if not isinstance(a, (coo_array, coo_matrix)): |
| 1389 | + # Try to convert to coo_array (prefer modern API) |
| 1390 | + a_coo = coo_array(a) |
1390 | 1391 | else: |
1391 | 1392 | a_coo = a |
1392 | 1393 |
|
@@ -1815,9 +1816,7 @@ def sparse_coo_data(self, a): |
1815 | 1816 | # JAX doesn't support sparse matrices, so this shouldn't be called |
1816 | 1817 | # But if it is, convert the dense array to sparse using scipy |
1817 | 1818 | a_np = self.to_numpy(a) |
1818 | | - from scipy.sparse import coo_matrix |
1819 | | - |
1820 | | - a_coo = coo_matrix(a_np) |
| 1819 | + a_coo = coo_array(a_np) |
1821 | 1820 | return a_coo.row, a_coo.col, a_coo.data, a_coo.shape |
1822 | 1821 |
|
1823 | 1822 | def where(self, condition, x=None, y=None): |
|
0 commit comments