Skip to content

Commit 46c233c

Browse files
authored
fix(exla): vectorized gather (#1595)
1 parent 592e0a1 commit 46c233c

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

exla/lib/exla/defn.ex

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,8 +1141,7 @@ defmodule EXLA.Defn do
11411141
end
11421142

11431143
batch_size = tensor_rank - length(axes)
1144-
offset_size = indices_rank - length(axes)
1145-
offset_dims = count_up(batch_size, offset_size)
1144+
offset_dims = count_up(batch_size, index_vector_dim)
11461145

11471146
Value.gather(
11481147
tensor,

exla/test/exla/backend_test.exs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,42 @@ defmodule EXLA.BackendTest do
225225
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
226226
end
227227

228+
test "gather vectorized regression" do
229+
gradients =
230+
Nx.tensor(
231+
[
232+
[1.0, 1.0],
233+
[-1.0, 1.0],
234+
[1.0, -1.0],
235+
[-1.0, -1.0]
236+
],
237+
backend: EXLA.Backend
238+
)
239+
240+
i =
241+
Nx.tensor([[0, 2, 3, 2, 2, 2, 2, 1]], type: {:u, 16}, backend: EXLA.Backend)
242+
|> Nx.vectorize([:x, :octaves])
243+
244+
result = Nx.gather(gradients, Nx.reshape(i, {1}))
245+
246+
assert_equal(
247+
result,
248+
Nx.tensor([
249+
[
250+
[1.0, 1.0],
251+
[1.0, -1.0],
252+
[-1.0, -1.0],
253+
[1.0, -1.0],
254+
[1.0, -1.0],
255+
[1.0, -1.0],
256+
[1.0, -1.0],
257+
[-1.0, 1.0]
258+
]
259+
])
260+
|> Nx.vectorize([:x, :octaves])
261+
)
262+
end
263+
228264
describe "quantized types" do
229265
test "s2" do
230266
tensor = Nx.s2(-1)

0 commit comments

Comments
 (0)