Skip to content

Commit f8ca89e

Browse files
committed
moved sparse matrix generation code to test file out of utils
1 parent a298a86 commit f8ca89e

File tree

2 files changed

+39
-36
lines changed

2 files changed

+39
-36
lines changed

ot/utils.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414
from scipy.spatial.distance import cdist
15-
from scipy.sparse import coo_array
1615
import sys
1716
import warnings
1817
from inspect import signature
@@ -435,39 +434,6 @@ def dist(
435434
return cdist(x1, x2, metric=metric)
436435

437436

438-
def get_sparse_test_matrices(n1, n2, k=2, seed=42, nx=None):
439-
if nx is None:
440-
nx = NumpyBackend()
441-
442-
rng = np.random.RandomState(seed)
443-
M_orig = rng.rand(n1, n2)
444-
445-
mask = np.zeros((n1, n2))
446-
for i in range(n1):
447-
j_list = rng.choice(n2, min(k, n2), replace=False)
448-
for j in j_list:
449-
mask[i, j] = 1
450-
for j in range(n2):
451-
i_list = rng.choice(n1, min(k, n1), replace=False)
452-
for i in i_list:
453-
mask[i, j] = 1
454-
455-
M_sparse_np = coo_array(M_orig * mask)
456-
rows, cols, data = M_sparse_np.row, M_sparse_np.col, M_sparse_np.data
457-
458-
if nx.__name__ == "numpy":
459-
M_sparse = M_sparse_np
460-
else:
461-
rows_b = nx.from_numpy(rows.astype(np.int64))
462-
cols_b = nx.from_numpy(cols.astype(np.int64))
463-
data_b = nx.from_numpy(data)
464-
M_sparse = nx.coo_matrix(data_b, rows_b, cols_b, shape=(n1, n2))
465-
466-
M_dense = nx.from_numpy(M_orig + 1e8 * (1 - mask))
467-
468-
return M_sparse, M_dense
469-
470-
471437
def dist0(n, method="lin_square"):
472438
r"""Compute standard cost matrices of size (`n`, `n`) for OT problems
473439

test/test_ot.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,43 @@ def test_dual_variables():
914914
assert constraint_violation.max() < 1e-8
915915

916916

917+
def _get_sparse_test_matrices(n1, n2, k=2, seed=42, nx=None):
918+
"""Helper function to create sparse and dense test matrices."""
919+
from scipy.sparse import coo_array
920+
from ot.backend import NumpyBackend
921+
922+
if nx is None:
923+
nx = NumpyBackend()
924+
925+
rng = np.random.RandomState(seed)
926+
M_orig = rng.rand(n1, n2)
927+
928+
mask = np.zeros((n1, n2))
929+
for i in range(n1):
930+
j_list = rng.choice(n2, min(k, n2), replace=False)
931+
for j in j_list:
932+
mask[i, j] = 1
933+
for j in range(n2):
934+
i_list = rng.choice(n1, min(k, n1), replace=False)
935+
for i in i_list:
936+
mask[i, j] = 1
937+
938+
M_sparse_np = coo_array(M_orig * mask)
939+
rows, cols, data = M_sparse_np.row, M_sparse_np.col, M_sparse_np.data
940+
941+
if nx.__name__ == "numpy":
942+
M_sparse = M_sparse_np
943+
else:
944+
rows_b = nx.from_numpy(rows.astype(np.int64))
945+
cols_b = nx.from_numpy(cols.astype(np.int64))
946+
data_b = nx.from_numpy(data)
947+
M_sparse = nx.coo_matrix(data_b, rows_b, cols_b, shape=(n1, n2))
948+
949+
M_dense = nx.from_numpy(M_orig + 1e8 * (1 - mask))
950+
951+
return M_sparse, M_dense
952+
953+
917954
def test_emd_sparse_vs_dense(nx):
918955
"""Test that sparse and dense EMD solvers produce identical results.
919956
@@ -929,7 +966,7 @@ def test_emd_sparse_vs_dense(nx):
929966
n2 = 100
930967
k = 2
931968

932-
M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=42, nx=nx)
969+
M_sparse, M_dense = _get_sparse_test_matrices(n1, n2, k=k, seed=42, nx=nx)
933970

934971
a = ot.utils.unif(n1, type_as=M_dense)
935972
b = ot.utils.unif(n2, type_as=M_dense)
@@ -971,7 +1008,7 @@ def test_emd2_sparse_vs_dense(nx):
9711008
n2 = 150
9721009
k = 2
9731010

974-
M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=43, nx=nx)
1011+
M_sparse, M_dense = _get_sparse_test_matrices(n1, n2, k=k, seed=43, nx=nx)
9751012

9761013
a = ot.utils.unif(n1, type_as=M_dense)
9771014
b = ot.utils.unif(n2, type_as=M_dense)

0 commit comments

Comments
 (0)