Skip to content

Commit a19753e

Browse files
committed
update sliced plans with sparse matrix for tf compatibility
1 parent fc51528 commit a19753e

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

ot/sliced.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,21 +873,24 @@ def sliced_plans(
873873
)
874874
** (1 / p)
875875
)
876-
* plan[k].data
876+
* plan[k]
877877
)
878878
for k in range(n_proj)
879879
]
880880
else: # metric == "sqeuclidean"
881881
costs = [
882882
nx.sum(
883883
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
884-
* plan[k].data
884+
* plan[k]
885885
)
886886
for k in range(n_proj)
887887
]
888888

889889
if dense and not str(nx) == "jax":
890890
plan = [nx.todense(plan[k]) for k in range(n_proj)]
891+
elif str(nx) == "jax":
892+
warnings.warn("JAX does not support sparse matrices, converting to dense")
893+
plan = [nx.todense(plan[k]) for k in range(n_proj)]
891894

892895
if log:
893896
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}

0 commit comments

Comments
 (0)