Skip to content

Commit fc51528

Browse files
committed
update sliced plans with sparse matrix for tf compatibility
2 parents 282ac99 + 652b49d commit fc51528

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

RELEASES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#### New features
66

7-
- Added Sliced OT plans (PR #757)
7+
- Added Sliced OT plans (PR #767)
88

99
## 0.9.6.post1
1010

ot/sliced.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Nicolas Courty <ncourty@irisa.fr>
77
# Rémi Flamary <remi.flamary@polytechnique.edu>
88
# Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
9+
# Laetitia Chapel <laetitia.chapel@irisa.fr>
910
#
1011
# License: MIT License
1112

@@ -820,21 +821,23 @@ def sliced_plans(
820821
for k in range(n_proj)
821822
]
822823

824+
if not dense and str(nx) == "jax":
825+
warnings.warn("JAX does not support sparse matrices, converting to dense")
826+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
827+
823828
else: # we compute plans
824829
_, plan = wasserstein_1d(
825830
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
826831
)
827832

828-
if str(nx) == "tensorflow": # tf does not support duplicate entries
829-
plan = [plan[k].tocsr().tocoo() for k in range(n_proj)]
830-
831-
if str(nx) == "jax":
832-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
833+
if str(nx) == "jax": # dense computation for jax
833834
if not dense:
834835
warnings.warn(
835-
"JAX does not support sparse matrices, converting" "to dense"
836+
"JAX does not support sparse matrices, converting to dense"
836837
)
837838

839+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
840+
838841
costs = [
839842
nx.sum(
840843
(
@@ -854,7 +857,11 @@ def sliced_plans(
854857
)
855858
for k in range(n_proj)
856859
]
860+
857861
else:
862+
if str(nx) == "tensorflow": # tf does not support multiple indexing
863+
plan = [plan[k].tocsr().tocoo() for k in range(n_proj)]
864+
858865
if metric in ("minkowski", "euclidean", "cityblock"):
859866
costs = [
860867
nx.sum(
@@ -870,7 +877,7 @@ def sliced_plans(
870877
)
871878
for k in range(n_proj)
872879
]
873-
else: # metric = "sqeuclidean"
880+
else: # metric == "sqeuclidean"
874881
costs = [
875882
nx.sum(
876883
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
@@ -879,7 +886,7 @@ def sliced_plans(
879886
for k in range(n_proj)
880887
]
881888

882-
if dense:
889+
if dense and not str(nx) == "jax":
883890
plan = [nx.todense(plan[k]) for k in range(n_proj)]
884891

885892
if log:

test/test_sliced.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
44
# Nicolas Courty <ncourty@irisa.fr>
55
# Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
6+
# Laetitia Chapel <laetitia.chapel@irisa.fr>
67
#
78
# License: MIT License
89

@@ -943,7 +944,7 @@ def test_sliced_plans_backends(nx):
943944
x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b)
944945

945946
thetas_b = ot.sliced.get_random_projections(
946-
d, n_proj, seed=0, backend=nx, type_as=x
947+
d, n_proj, seed=0, backend=nx, type_as=x_b
947948
).T
948949
thetas = nx.to_numpy(thetas_b)
949950

0 commit comments

Comments
 (0)