Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions exla/test/exla/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,23 @@ defmodule EXLA.NxRandomTest do
)
end
end

@tag :cuda_required
test "regression on single-dimensional and multi-dimensional Random.shuffle" do
# these are put in the process dictionary, so it's thread-safe to do this
Nx.default_backend({EXLA.Backend, client: :cuda})
Nx.Defn.default_options(compiler: EXLA, client: :cuda)
key = Nx.Random.key(127)

t1 = Nx.iota({2, 100})
t2 = Nx.iota({100})

{t1_shuffled_0, key} = Nx.Random.shuffle(key, t1, axis: 0)
{t1_shuffled_1, key} = Nx.Random.shuffle(key, t1, axis: 1)
{t2_shuffled, _key} = Nx.Random.shuffle(key, t2)

assert_equal(Nx.sort(t1_shuffled_0, axis: 0), t1)
assert_equal(Nx.sort(t1_shuffled_1, axis: 1), t1)
assert_equal(Nx.sort(t2_shuffled), t2)
end
end
25 changes: 20 additions & 5 deletions nx/lib/nx/random.ex
Original file line number Diff line number Diff line change
Expand Up @@ -799,15 +799,21 @@ defmodule Nx.Random do
axis = opts[:axis]

if opts[:independent] do
shuffle_independent(key, tensor, axis: axis)
shuffle_independent(key, tensor, axis: axis, independent: true)
else
{idx, key} = shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}), axis: 0)
{idx, key} =
shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}),
axis: 0,
independent: false
)

{Nx.take(tensor, idx, axis: axis), key}
end
end

defnp shuffle_independent(key, tensor, opts) do
axis = opts[:axis]
independent = opts[:independent]

# reference: https://github.com/google/jax/blob/838bc454895ed2086563301936fb0d6d852fd198/jax/_src/random.py#L437
exponent = 3
Expand All @@ -821,16 +827,25 @@ defmodule Nx.Random do
while {i = 0, tensor, key}, i < num_rounds do
keys = split(key)
sort_keys = random_bits(keys[1], shape: tensor.shape)
tensor = sort_key_val(tensor, sort_keys, axis: axis)
tensor = sort_key_val(tensor, sort_keys, axis: axis, independent: independent)
{i + 1, tensor, keys[0]}
end

{out, key}
end

defnp sort_key_val(tensor, sort_keys, opts \\ []) do
deftransformp sort_key_val(tensor, sort_keys, opts \\ []) do
idx = Nx.argsort(sort_keys, axis: opts[:axis])
Nx.take_along_axis(tensor, idx, axis: opts[:axis])

if opts[:independent] do
# We need to use take_along_axis in the independent case because
# the sort_keys tensor has the same shape as the input tensor.
Nx.take_along_axis(tensor, idx, axis: opts[:axis])
else
# In the non-independent case, we use take because the sort_keys
# tensor is a 1D tensor.
Nx.take(tensor, idx, axis: opts[:axis])
end
end

@choice_options """
Expand Down
Loading