Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,16 @@ defmodule EXLA.MLIR.Value do
comparison_type =
cond do
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
attr_comparison_type(:float)
[compare_type: attr_comparison_type(:float)]

Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
attr_comparison_type(:float)
[compare_type: attr_comparison_type(:float)]

true ->
attr_comparison_type(:notype)
[]
end

attributes = [
comparison_direction: attr_comparison_direction(direction),
compare_type: comparison_type
]
attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type

result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])

Expand Down
19 changes: 19 additions & 0 deletions exla/test/exla/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,23 @@ defmodule EXLA.NxRandomTest do
)
end
end

@tag :cuda_required
test "regression on single-dimensional and multi-dimensional Random.shuffle" do
# these are put in the process dictionary, so it's thread-safe to do this
Nx.default_backend({EXLA.Backend, client: :cuda})
Nx.Defn.default_options(compiler: EXLA, client: :cuda)
key = Nx.Random.key(127)

t1 = Nx.iota({2, 100})
t2 = Nx.iota({100})

{t1_shuffled_0, key} = Nx.Random.shuffle(key, t1, axis: 0)
{t1_shuffled_1, key} = Nx.Random.shuffle(key, t1, axis: 1)
{t2_shuffled, _key} = Nx.Random.shuffle(key, t2)

assert_equal(Nx.sort(t1_shuffled_0, axis: 0), t1)
assert_equal(Nx.sort(t1_shuffled_1, axis: 1), t1)
assert_equal(Nx.sort(t2_shuffled), t2)
end
end
Loading