12 | 12 | import ot |
13 | 13 | from ot.datasets import make_1D_gauss as gauss |
14 | 14 | from ot.backend import torch, tf, get_backend |
15 | | -from scipy.sparse import coo_array |
16 | 15 |
17 | 16 |
18 | 17 | def test_emd_dimension_and_mass_mismatch(): |
@@ -918,156 +917,70 @@ def test_dual_variables(): |
918 | 917 | def test_emd_sparse_vs_dense(nx): |
919 | 918 | """Test that sparse and dense EMD solvers produce identical results. |
920 | 919 |
921 | | - Uses augmented k-NN graph approach: first solves with dense solver to |
922 | | - identify needed edges, then compares both solvers on the same graph. |
| 920 | + Uses random sparse graphs with k=2 edges per row/column, which guarantees
| 921 | + that a feasible transport plan exists for uniform marginals.
923 | 922 | """ |
924 | 923 | # Skip for backends that don't support sparse matrices |
925 | 924 | backend_name = nx.__class__.__name__.lower() |
926 | 925 | if "jax" in backend_name or "tensorflow" in backend_name: |
927 | 926 | pytest.skip("Backend does not support sparse matrices") |
928 | 927 |
929 | | - n_source = 100 |
930 | | - n_target = 100 |
931 | | - k = 10 |
| 928 | + n1 = 100 |
| 929 | + n2 = 100 |
| 930 | + k = 2 |
932 | 931 |
933 | | - rng = np.random.RandomState(42) |
| 932 | + M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=42, nx=nx) |
934 | 933 |
935 | | - x_source = rng.randn(n_source, 2) |
936 | | - x_target = rng.randn(n_target, 2) + 0.5 |
| 934 | + a = ot.utils.unif(n1, type_as=M_dense) |
| 935 | + b = ot.utils.unif(n2, type_as=M_dense) |
937 | 936 |
938 | | - a = ot.utils.unif(n_source) |
939 | | - b = ot.utils.unif(n_target) |
940 | | - |
941 | | - C = ot.dist(x_source, x_target) |
942 | | - |
943 | | - # Compute k-NN sparse cost matrix |
944 | | - C_knn = ot.utils.dist_knn(x_source, x_target, k=k, metric="sqeuclidean") |
945 | | - |
946 | | - # First pass: solve with k-NN to identify active edges |
947 | | - large_cost = 1e8 |
948 | | - C_dense_infty = np.full((n_source, n_target), large_cost) |
949 | | - C_knn_array = C_knn.toarray() |
950 | | - C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] |
951 | | - |
952 | | - G_dense_initial = ot.emd(a, b, C_dense_infty) |
953 | | - eps = 1e-9 |
954 | | - active_mask = G_dense_initial > eps |
955 | | - knn_mask = C_knn_array > 0 |
956 | | - extra_edges_mask = active_mask & ~knn_mask |
957 | | - |
958 | | - rows_aug = [] |
959 | | - cols_aug = [] |
960 | | - data_aug = [] |
961 | | - |
962 | | - knn_rows, knn_cols = np.where(knn_mask) |
963 | | - for i, j in zip(knn_rows, knn_cols): |
964 | | - rows_aug.append(i) |
965 | | - cols_aug.append(j) |
966 | | - data_aug.append(C[i, j]) |
967 | | - |
968 | | - extra_rows, extra_cols = np.where(extra_edges_mask) |
969 | | - for i, j in zip(extra_rows, extra_cols): |
970 | | - rows_aug.append(i) |
971 | | - cols_aug.append(j) |
972 | | - data_aug.append(C[i, j]) |
973 | | - |
974 | | - C_augmented = coo_array( |
975 | | - (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) |
976 | | - ) |
977 | | - |
978 | | - C_augmented_dense = np.full((n_source, n_target), large_cost) |
979 | | - C_augmented_dense[rows_aug, cols_aug] = data_aug |
980 | | - |
981 | | - G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) |
982 | | - G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True) |
| 937 | + # Solve with both dense and sparse solvers |
| 938 | + G_dense, log_dense = ot.emd(a, b, M_dense, log=True) |
| 939 | + G_sparse, log_sparse = ot.emd(a, b, M_sparse, log=True) |
983 | 940 |
984 | 941 | cost_dense = log_dense["cost"] |
985 | 942 | cost_sparse = log_sparse["cost"] |
986 | | - |
987 | 943 | np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) |
988 | 944 |
989 | | - # For dense, G_dense is returned; for sparse, reconstruct from flow edges |
990 | 945 | np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7) |
991 | 946 | np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7) |
992 | 947 |
993 | | - # G_sparse is now returned as a sparse matrix |
994 | | - from scipy.sparse import issparse |
| 948 | + assert nx.issparse(G_sparse), "Sparse solver should return a sparse matrix" |
995 | 949 |
996 | | - assert issparse(G_sparse), "Sparse solver should return a sparse matrix" |
997 | | - |
998 | | - # Convert to dense for marginal checks |
999 | | - G_sparse_dense = G_sparse.toarray() |
1000 | | - np.testing.assert_allclose(a, G_sparse_dense.sum(1), rtol=1e-5, atol=1e-7) |
1001 | | - np.testing.assert_allclose(b, G_sparse_dense.sum(0), rtol=1e-5, atol=1e-7) |
| 950 | + G_sparse_dense = nx.todense(G_sparse) |
| 951 | + np.testing.assert_allclose( |
| 952 | + a, nx.to_numpy(nx.sum(G_sparse_dense, 1)), rtol=1e-5, atol=1e-7 |
| 953 | + ) |
| 954 | + np.testing.assert_allclose( |
| 955 | + b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7 |
| 956 | + ) |
1002 | 957 |
1003 | 958 |
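Why k = 2 edges per row/column guarantees feasibility in the square case: the union of k random permutation matrices gives every row and every column exactly k edges (counted with multiplicity), so spreading mass 1/(k*n) over each permutation's edges is already a valid coupling of the uniform marginals. Below is a minimal sketch of such a construction using only numpy/scipy. The internals of the `ot.utils.get_sparse_test_matrices` helper called above are not shown in this diff, so this illustrates the feasibility argument rather than the helper itself:

```python
import numpy as np
import ot
from scipy.sparse import coo_array


def k_regular_sparse_cost(n, k, seed=0):
    """Sparse cost supported on the union of k random permutation matrices."""
    rng = np.random.RandomState(seed)
    rows = np.tile(np.arange(n), k)  # each permutation contributes one edge per row
    cols = np.concatenate([rng.permutation(n) for _ in range(k)])
    data = rng.uniform(0.1, 1.0, size=k * n)  # strictly positive costs
    # coo_array sums duplicate (i, j) entries; merging coincident edges does
    # not break feasibility of the uniform coupling G = sum_i P_i / (k * n)
    return coo_array((data, (rows, cols)), shape=(n, n))


M = k_regular_sparse_cost(100, k=2, seed=42)
a = b = np.full(100, 1.0 / 100)
cost = ot.emd2(a, b, M)  # the sparse code path exercised by the tests here
```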
1004 | 959 | def test_emd2_sparse_vs_dense(nx): |
1005 | | - """Test that sparse and dense emd2 solvers produce identical results. |
| 960 | + """Test that sparse and dense emd2 solvers produce identical costs. |
1006 | 961 |
1007 | | - Uses augmented k-NN graph approach: first solves with dense solver to |
1008 | | - identify needed edges, then compares both solvers on the same graph. |
| 962 | + Uses random sparse graphs with k=2 edges per row/column, which guarantees
| 963 | + that a feasible transport plan exists for uniform marginals.
1009 | 964 | """ |
1010 | 965 | # Skip for backends that don't support sparse matrices |
1011 | 966 | backend_name = nx.__class__.__name__.lower() |
1012 | 967 | if "jax" in backend_name or "tensorflow" in backend_name: |
1013 | 968 | pytest.skip("Backend does not support sparse matrices") |
1014 | 969 |
1015 | | - n_source = 100 |
1016 | | - n_target = 100 |
1017 | | - k = 10 |
1018 | | - |
1019 | | - rng = np.random.RandomState(42) |
1020 | | - |
1021 | | - x_source = rng.randn(n_source, 2) |
1022 | | - x_target = rng.randn(n_target, 2) + 0.5 |
1023 | | - |
1024 | | - a = ot.utils.unif(n_source) |
1025 | | - b = ot.utils.unif(n_target) |
1026 | | - |
1027 | | - C = ot.dist(x_source, x_target) |
1028 | | - |
1029 | | - # Compute k-NN sparse cost matrix |
1030 | | - C_knn = ot.utils.dist_knn(x_source, x_target, k=k, metric="sqeuclidean") |
1031 | | - |
1032 | | - # First pass: solve with k-NN to identify active edges |
1033 | | - large_cost = 1e8 |
1034 | | - C_dense_infty = np.full((n_source, n_target), large_cost) |
1035 | | - C_knn_array = C_knn.toarray() |
1036 | | - C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] |
| 970 | + n1 = 100 |
| 971 | + n2 = 150 |
| 972 | + k = 2 |
1037 | 973 |
1038 | | - G_dense_initial = ot.emd(a, b, C_dense_infty) |
| 974 | + M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=43, nx=nx) |
1039 | 975 |
1040 | | - eps = 1e-9 |
1041 | | - active_mask = G_dense_initial > eps |
1042 | | - knn_mask = C_knn_array > 0 |
1043 | | - extra_edges_mask = active_mask & ~knn_mask |
| 976 | + a = ot.utils.unif(n1, type_as=M_dense) |
| 977 | + b = ot.utils.unif(n2, type_as=M_dense) |
1044 | 978 |
1045 | | - rows_aug = [] |
1046 | | - cols_aug = [] |
1047 | | - data_aug = [] |
1048 | | - |
1049 | | - knn_rows, knn_cols = np.where(knn_mask) |
1050 | | - for i, j in zip(knn_rows, knn_cols): |
1051 | | - rows_aug.append(i) |
1052 | | - cols_aug.append(j) |
1053 | | - data_aug.append(C[i, j]) |
1054 | | - |
1055 | | - extra_rows, extra_cols = np.where(extra_edges_mask) |
1056 | | - for i, j in zip(extra_rows, extra_cols): |
1057 | | - rows_aug.append(i) |
1058 | | - cols_aug.append(j) |
1059 | | - data_aug.append(C[i, j]) |
1060 | | - |
1061 | | - C_augmented = coo_array( |
1062 | | - (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) |
1063 | | - ) |
1064 | | - |
1065 | | - C_augmented_dense = np.full((n_source, n_target), large_cost) |
1066 | | - C_augmented_dense[rows_aug, cols_aug] = data_aug |
1067 | | - |
1068 | | - cost_dense = ot.emd2(a, b, C_augmented_dense) |
1069 | | - cost_sparse = ot.emd2(a, b, C_augmented) |
| 979 | + # Solve with both dense and sparse solvers |
| 980 | + cost_dense = ot.emd2(a, b, M_dense) |
| 981 | + cost_sparse = ot.emd2(a, b, M_sparse) |
1070 | 982 |
| 983 | + # Check costs match |
1071 | 984 | np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) |
1072 | 985 |
1073 | 986 |
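For the rectangular case (n1 = 100, n2 = 150 in the second test), "k edges per row/column" cannot hold exactly on both sides, so feasibility needs a different construction. One classical option is the staircase support of the monotone (northwest-corner) coupling, which carries a feasible plan for uniform marginals by construction and touches every row and every column; random extra edges can then be padded on top to reach k per row. A sketch under these assumptions (again illustrative, not necessarily what the helper does):

```python
import numpy as np
from scipy.sparse import coo_array


def monotone_support_cost(n1, n2, seed=0):
    """Sparse cost on the northwest-corner staircase, feasible for uniform marginals."""
    rng = np.random.RandomState(seed)
    rows, cols = [], []
    i = j = 0
    a_rem, b_rem = 1.0 / n1, 1.0 / n2  # mass left in the current row/column cell
    while i < n1 and j < n2:
        rows.append(i)
        cols.append(j)
        m = min(a_rem, b_rem)  # ship as much mass as both cells allow
        a_rem -= m
        b_rem -= m
        if a_rem < 1e-12:  # row cell exhausted: move down
            i += 1
            a_rem = 1.0 / n1
        if b_rem < 1e-12:  # column cell exhausted: move right
            j += 1
            b_rem = 1.0 / n2
    data = rng.uniform(0.1, 1.0, size=len(rows))  # at most n1 + n2 - 1 edges
    return coo_array((data, (rows, cols)), shape=(n1, n2))


M = monotone_support_cost(100, 150, seed=43)
```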