Skip to content

Commit 652b49d

Browse files
committed
jax .data fix + backend typing fix in sliced_plans test function
1 parent dd1b31f commit 652b49d

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ot/sliced.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,15 +880,15 @@ def sliced_plans(
880880
)
881881
** (1 / p)
882882
)
883-
* plan[k].data
883+
* plan[k]
884884
)
885885
for k in range(n_proj)
886886
]
887887
elif metric == "sqeuclidean":
888888
costs = [
889889
nx.sum(
890890
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
891-
* plan[k].data
891+
* plan[k]
892892
)
893893
for k in range(n_proj)
894894
]

test/test_sliced.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,9 @@ def test_sliced_plans_backends(nx):
925925

926926
x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b)
927927

928-
thetas_b = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T
928+
thetas_b = ot.sliced.get_random_projections(
929+
d, n_proj, seed=0, backend=nx, type_as=x_b
930+
).T
929931
thetas = nx.to_numpy(thetas_b)
930932

931933
context = (

0 commit comments

Comments
 (0)