@@ -914,6 +914,43 @@ def test_dual_variables():
914914 assert constraint_violation .max () < 1e-8
915915
916916
917+ def _get_sparse_test_matrices (n1 , n2 , k = 2 , seed = 42 , nx = None ):
918+ """Helper function to create sparse and dense test matrices."""
919+ from scipy .sparse import coo_array
920+ from ot .backend import NumpyBackend
921+
922+ if nx is None :
923+ nx = NumpyBackend ()
924+
925+ rng = np .random .RandomState (seed )
926+ M_orig = rng .rand (n1 , n2 )
927+
928+ mask = np .zeros ((n1 , n2 ))
929+ for i in range (n1 ):
930+ j_list = rng .choice (n2 , min (k , n2 ), replace = False )
931+ for j in j_list :
932+ mask [i , j ] = 1
933+ for j in range (n2 ):
934+ i_list = rng .choice (n1 , min (k , n1 ), replace = False )
935+ for i in i_list :
936+ mask [i , j ] = 1
937+
938+ M_sparse_np = coo_array (M_orig * mask )
939+ rows , cols , data = M_sparse_np .row , M_sparse_np .col , M_sparse_np .data
940+
941+ if nx .__name__ == "numpy" :
942+ M_sparse = M_sparse_np
943+ else :
944+ rows_b = nx .from_numpy (rows .astype (np .int64 ))
945+ cols_b = nx .from_numpy (cols .astype (np .int64 ))
946+ data_b = nx .from_numpy (data )
947+ M_sparse = nx .coo_matrix (data_b , rows_b , cols_b , shape = (n1 , n2 ))
948+
949+ M_dense = nx .from_numpy (M_orig + 1e8 * (1 - mask ))
950+
951+ return M_sparse , M_dense
952+
953+
917954def test_emd_sparse_vs_dense (nx ):
918955 """Test that sparse and dense EMD solvers produce identical results.
919956
@@ -929,7 +966,7 @@ def test_emd_sparse_vs_dense(nx):
929966 n2 = 100
930967 k = 2
931968
932- M_sparse , M_dense = ot . utils . get_sparse_test_matrices (n1 , n2 , k = k , seed = 42 , nx = nx )
969+ M_sparse , M_dense = _get_sparse_test_matrices (n1 , n2 , k = k , seed = 42 , nx = nx )
933970
934971 a = ot .utils .unif (n1 , type_as = M_dense )
935972 b = ot .utils .unif (n2 , type_as = M_dense )
@@ -971,7 +1008,7 @@ def test_emd2_sparse_vs_dense(nx):
9711008 n2 = 150
9721009 k = 2
9731010
974- M_sparse , M_dense = ot . utils . get_sparse_test_matrices (n1 , n2 , k = k , seed = 43 , nx = nx )
1011+ M_sparse , M_dense = _get_sparse_test_matrices (n1 , n2 , k = k , seed = 43 , nx = nx )
9751012
9761013 a = ot .utils .unif (n1 , type_as = M_dense )
9771014 b = ot .utils .unif (n2 , type_as = M_dense )
0 commit comments