Commit a298a86
Simplified the test comparing sparse to dense to be more approachable, with proper backend adaptation
1 parent 4885368 commit a298a86

File tree

2 files changed: +61 -187 lines changed

ot/utils.py

Lines changed: 29 additions & 68 deletions
@@ -13,7 +13,6 @@
 import numpy as np
 from scipy.spatial.distance import cdist
 from scipy.sparse import coo_array
-from sklearn.neighbors import NearestNeighbors
 import sys
 import warnings
 from inspect import signature
@@ -436,75 +435,37 @@ def dist(
     return cdist(x1, x2, metric=metric)
 
 
-def dist_knn(
-    x1,
-    x2=None,
-    k=10,
-    metric="euclidean",
-    p=2,
-):
-    r"""Compute sparse k-nearest neighbors distance matrix in COO format
-
-    This function efficiently computes a sparse distance matrix containing only
-    the k-nearest neighbors for each sample, which is useful for large-scale
-    optimal transport problems where the full dense distance matrix would be
-    prohibitively large.
-
-    Parameters
-    ----------
-    x1 : array-like, shape (n1, d)
-        Matrix with `n1` samples of size `d`
-    x2 : array-like, shape (n2, d), optional
-        Matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
-    k : int, optional (default=10)
-        Number of nearest neighbors to keep for each sample
-    metric : str, optional (default='euclidean')
-        Distance metric to use. Supported metrics include: 'euclidean', 'manhattan',
-        'chebyshev', 'minkowski', 'cityblock', 'cosine', 'l1', 'l2', 'sqeuclidean',
-        and others supported by sklearn.neighbors.NearestNeighbors
-    p : float, optional (default=2)
-        Parameter for the Minkowski metric
-
-    Returns
-    -------
-    M_sparse : scipy.sparse.coo_array, shape (n1, n2)
-        Sparse distance matrix in COO format containing only k-nearest neighbors
-
-    """
-    nx = get_backend(x1, x2)
-
-    # Convert to numpy for k-NN computation
-    x1_np = nx.to_numpy(x1)
-    x2_np = nx.to_numpy(x2) if x2 is not None else x1_np
-
-    n1 = x1_np.shape[0]
-    n2 = x2_np.shape[0]
-    k_actual = min(k, n2)  # Handle case where k > n2
-
-    # Use sklearn's efficient k-NN implementation
-    metric_params = {}
-    if metric == "minkowski":
-        metric_params["p"] = p
-
-    nbrs = NearestNeighbors(
-        n_neighbors=k_actual,
-        algorithm="auto",
-        metric=metric,
-        metric_params=metric_params if metric_params else None,
-    )
-    nbrs.fit(x2_np)
-
-    # Find k-nearest neighbors and their distances
-    distances, indices = nbrs.kneighbors(x1_np)
-
-    # Build sparse matrix in COO format
-    rows = np.repeat(np.arange(n1), k_actual)
-    cols = indices.ravel()
-    data = distances.ravel()
+def get_sparse_test_matrices(n1, n2, k=2, seed=42, nx=None):
+    if nx is None:
+        nx = NumpyBackend()
+
+    rng = np.random.RandomState(seed)
+    M_orig = rng.rand(n1, n2)
+
+    mask = np.zeros((n1, n2))
+    for i in range(n1):
+        j_list = rng.choice(n2, min(k, n2), replace=False)
+        for j in j_list:
+            mask[i, j] = 1
+    for j in range(n2):
+        i_list = rng.choice(n1, min(k, n1), replace=False)
+        for i in i_list:
+            mask[i, j] = 1
+
+    M_sparse_np = coo_array(M_orig * mask)
+    rows, cols, data = M_sparse_np.row, M_sparse_np.col, M_sparse_np.data
+
+    if nx.__name__ == "numpy":
+        M_sparse = M_sparse_np
+    else:
+        rows_b = nx.from_numpy(rows.astype(np.int64))
+        cols_b = nx.from_numpy(cols.astype(np.int64))
+        data_b = nx.from_numpy(data)
+        M_sparse = nx.coo_matrix(data_b, rows_b, cols_b, shape=(n1, n2))
 
-    M_sparse = coo_array((data, (rows, cols)), shape=(n1, n2))
+    M_dense = nx.from_numpy(M_orig + 1e8 * (1 - mask))
 
-    return M_sparse
+    return M_sparse, M_dense
 
 
 def dist0(n, method="lin_square"):
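
For readers skimming the diff, here is a minimal sketch of what the new helper produces, using the default NumPy backend. It assumes a POT build that includes this commit, so that `ot.utils.get_sparse_test_matrices` is available; the sizes below are purely illustrative.

```python
# Minimal sketch, assuming a POT build that includes this commit so that
# ot.utils.get_sparse_test_matrices exists (NumPy backend by default).
import numpy as np
import ot

n1, n2, k = 20, 30, 2
M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=0)

# Under the NumPy backend the sparse matrix is a scipy.sparse.coo_array
print(type(M_sparse).__name__, M_sparse.shape, M_sparse.nnz)

# On kept edges the dense matrix reproduces the sparse entries exactly;
# all masked-out entries carry the ~1e8 penalty so a dense solver avoids them.
assert np.allclose(M_dense[M_sparse.row, M_sparse.col], M_sparse.data)
assert M_dense.max() > 1e7
```

The `1e8` offset plays the role of an effectively infinite cost: the dense matrix describes the same edge set as the sparse one, so the dense and sparse solvers face the same transport problem.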

test/test_ot.py

Lines changed: 32 additions & 119 deletions
@@ -12,7 +12,6 @@
 import ot
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch, tf, get_backend
-from scipy.sparse import coo_array
 
 
 def test_emd_dimension_and_mass_mismatch():
@@ -918,156 +917,70 @@ def test_dual_variables():
 def test_emd_sparse_vs_dense(nx):
     """Test that sparse and dense EMD solvers produce identical results.
 
-    Uses augmented k-NN graph approach: first solves with dense solver to
-    identify needed edges, then compares both solvers on the same graph.
+    Uses random sparse graphs with k=2 edges per row/column, which guarantees
+    feasibility with uniform marginals.
     """
     # Skip for backends that don't support sparse matrices
     backend_name = nx.__class__.__name__.lower()
     if "jax" in backend_name or "tensorflow" in backend_name:
         pytest.skip("Backend does not support sparse matrices")
 
-    n_source = 100
-    n_target = 100
-    k = 10
+    n1 = 100
+    n2 = 100
+    k = 2
 
-    rng = np.random.RandomState(42)
+    M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=42, nx=nx)
 
-    x_source = rng.randn(n_source, 2)
-    x_target = rng.randn(n_target, 2) + 0.5
+    a = ot.utils.unif(n1, type_as=M_dense)
+    b = ot.utils.unif(n2, type_as=M_dense)
 
-    a = ot.utils.unif(n_source)
-    b = ot.utils.unif(n_target)
-
-    C = ot.dist(x_source, x_target)
-
-    # Compute k-NN sparse cost matrix
-    C_knn = ot.utils.dist_knn(x_source, x_target, k=k, metric="sqeuclidean")
-
-    # First pass: solve with k-NN to identify active edges
-    large_cost = 1e8
-    C_dense_infty = np.full((n_source, n_target), large_cost)
-    C_knn_array = C_knn.toarray()
-    C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0]
-
-    G_dense_initial = ot.emd(a, b, C_dense_infty)
-    eps = 1e-9
-    active_mask = G_dense_initial > eps
-    knn_mask = C_knn_array > 0
-    extra_edges_mask = active_mask & ~knn_mask
-
-    rows_aug = []
-    cols_aug = []
-    data_aug = []
-
-    knn_rows, knn_cols = np.where(knn_mask)
-    for i, j in zip(knn_rows, knn_cols):
-        rows_aug.append(i)
-        cols_aug.append(j)
-        data_aug.append(C[i, j])
-
-    extra_rows, extra_cols = np.where(extra_edges_mask)
-    for i, j in zip(extra_rows, extra_cols):
-        rows_aug.append(i)
-        cols_aug.append(j)
-        data_aug.append(C[i, j])
-
-    C_augmented = coo_array(
-        (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
-    )
-
-    C_augmented_dense = np.full((n_source, n_target), large_cost)
-    C_augmented_dense[rows_aug, cols_aug] = data_aug
-
-    G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True)
-    G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True)
+    # Solve with both dense and sparse solvers
+    G_dense, log_dense = ot.emd(a, b, M_dense, log=True)
+    G_sparse, log_sparse = ot.emd(a, b, M_sparse, log=True)
 
     cost_dense = log_dense["cost"]
     cost_sparse = log_sparse["cost"]
-
     np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)
 
-    # For dense, G_dense is returned; for sparse, reconstruct from flow edges
     np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7)
     np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7)
 
-    # G_sparse is now returned as a sparse matrix
-    from scipy.sparse import issparse
+    assert nx.issparse(G_sparse), "Sparse solver should return a sparse matrix"
 
-    assert issparse(G_sparse), "Sparse solver should return a sparse matrix"
-
-    # Convert to dense for marginal checks
-    G_sparse_dense = G_sparse.toarray()
-    np.testing.assert_allclose(a, G_sparse_dense.sum(1), rtol=1e-5, atol=1e-7)
-    np.testing.assert_allclose(b, G_sparse_dense.sum(0), rtol=1e-5, atol=1e-7)
+    G_sparse_dense = nx.todense(G_sparse)
+    np.testing.assert_allclose(
+        a, nx.to_numpy(nx.sum(G_sparse_dense, 1)), rtol=1e-5, atol=1e-7
+    )
+    np.testing.assert_allclose(
+        b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7
+    )
 
 
 def test_emd2_sparse_vs_dense(nx):
-    """Test that sparse and dense emd2 solvers produce identical results.
+    """Test that sparse and dense emd2 solvers produce identical costs.
 
-    Uses augmented k-NN graph approach: first solves with dense solver to
-    identify needed edges, then compares both solvers on the same graph.
+    Uses random sparse graphs with k=2 edges per row/column, which guarantees
+    feasibility with uniform marginals.
     """
     # Skip for backends that don't support sparse matrices
     backend_name = nx.__class__.__name__.lower()
     if "jax" in backend_name or "tensorflow" in backend_name:
         pytest.skip("Backend does not support sparse matrices")
 
-    n_source = 100
-    n_target = 100
-    k = 10
-
-    rng = np.random.RandomState(42)
-
-    x_source = rng.randn(n_source, 2)
-    x_target = rng.randn(n_target, 2) + 0.5
-
-    a = ot.utils.unif(n_source)
-    b = ot.utils.unif(n_target)
-
-    C = ot.dist(x_source, x_target)
-
-    # Compute k-NN sparse cost matrix
-    C_knn = ot.utils.dist_knn(x_source, x_target, k=k, metric="sqeuclidean")
-
-    # First pass: solve with k-NN to identify active edges
-    large_cost = 1e8
-    C_dense_infty = np.full((n_source, n_target), large_cost)
-    C_knn_array = C_knn.toarray()
-    C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0]
+    n1 = 100
+    n2 = 150
+    k = 2
 
-    G_dense_initial = ot.emd(a, b, C_dense_infty)
+    M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=k, seed=43, nx=nx)
 
-    eps = 1e-9
-    active_mask = G_dense_initial > eps
-    knn_mask = C_knn_array > 0
-    extra_edges_mask = active_mask & ~knn_mask
+    a = ot.utils.unif(n1, type_as=M_dense)
+    b = ot.utils.unif(n2, type_as=M_dense)
 
-    rows_aug = []
-    cols_aug = []
-    data_aug = []
-
-    knn_rows, knn_cols = np.where(knn_mask)
-    for i, j in zip(knn_rows, knn_cols):
-        rows_aug.append(i)
-        cols_aug.append(j)
-        data_aug.append(C[i, j])
-
-    extra_rows, extra_cols = np.where(extra_edges_mask)
-    for i, j in zip(extra_rows, extra_cols):
-        rows_aug.append(i)
-        cols_aug.append(j)
-        data_aug.append(C[i, j])
-
-    C_augmented = coo_array(
-        (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)
-    )
-
-    C_augmented_dense = np.full((n_source, n_target), large_cost)
-    C_augmented_dense[rows_aug, cols_aug] = data_aug
-
-    cost_dense = ot.emd2(a, b, C_augmented_dense)
-    cost_sparse = ot.emd2(a, b, C_augmented)
+    # Solve with both dense and sparse solvers
+    cost_dense = ot.emd2(a, b, M_dense)
+    cost_sparse = ot.emd2(a, b, M_sparse)
 
+    # Check costs match
     np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)
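
To close the loop, here is a standalone sketch of the behaviour the two rewritten tests pin down, run directly under the NumPy backend. It assumes a POT build containing this commit, where `ot.emd` and `ot.emd2` accept a sparse COO cost matrix and `ot.emd` returns the plan in sparse form, as the tests above assert.

```python
# Standalone sketch of the tests' core assertions, using the NumPy backend.
# Assumes a POT build that includes this commit (sparse cost support in
# ot.emd / ot.emd2 and the ot.utils.get_sparse_test_matrices helper).
import numpy as np
import ot

n1, n2 = 50, 50
M_sparse, M_dense = ot.utils.get_sparse_test_matrices(n1, n2, k=2, seed=42)
a = ot.utils.unif(n1)
b = ot.utils.unif(n2)

# Dense solver on the penalized full matrix vs. sparse solver on the COO
# graph: both should report the same optimal cost.
cost_dense = ot.emd2(a, b, M_dense)
cost_sparse = ot.emd2(a, b, M_sparse)
np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)

# The sparse plan stays sparse and satisfies both marginals.
G_sparse, log_sparse = ot.emd(a, b, M_sparse, log=True)
G = G_sparse.toarray()
np.testing.assert_allclose(a, G.sum(axis=1), rtol=1e-5, atol=1e-7)
np.testing.assert_allclose(b, G.sum(axis=0), rtol=1e-5, atol=1e-7)
```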