Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 1 addition & 29 deletions exla/test/support/exla_case.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,14 @@ defmodule EXLA.Case do
using do
quote do
import EXLA.Case
end
end

defmacro assert_equal(left, right) do
# Assert against binary backend tensors to show diff on failure
quote do
assert unquote(left) |> to_binary_backend() == unquote(right) |> to_binary_backend()
import Nx.Testing
end
end

def to_binary_backend(tensor) do
Nx.backend_copy(tensor, Nx.BinaryBackend)
end

def assert_all_close(left, right, opts \\ []) do
atol = opts[:atol] || 1.0e-4
rtol = opts[:rtol] || 1.0e-4

equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)

if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do
flunk("""
expected

#{inspect(left)}

to be within tolerance of

#{inspect(right)}
""")
end
end

def is_mac_arm? do
Application.fetch_env!(:exla, :is_mac_arm)
end
Expand Down
83 changes: 83 additions & 0 deletions nx/lib/nx/testing.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
defmodule Nx.Testing do
@moduledoc """
Testing functions for Nx tensor assertions.

This module provides functions for asserting tensor equality and
approximate equality within specified tolerances.
"""

import ExUnit.Assertions
import Nx, only: [is_tensor: 1]

@doc """
Asserts that two tensors are exactly equal.

This handles NaN values correctly by considering NaN == NaN as true.
"""
def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do
if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do
flunk("""
Tensor assertion failed.
left: #{inspect(left)}
right: #{inspect(right)}
""")
end
end

def assert_equal(left, right) do
if !tensor_equal?(left, right) do
flunk("""
Tensor assertion failed.
left: #{inspect(left)}
right: #{inspect(right)}
""")
end
end

def tensor_equal?(left, right) do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be private?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It was originally part of one of the assert_equal clauses but we wanted to reuse it in both.

Forgot to set it as defp :p

both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))

left
|> Nx.equal(right)
|> Nx.logical_or(both_nan)
|> Nx.all()
|> Nx.to_flat_list()
|> Enum.all?(&(&1 == 1))
end

@doc """
Asserts that two tensors are approximately equal within the given tolerances.

See also:

* `Nx.all_close/2` - The underlying function that performs the comparison.

## Options

* `:atol` - The absolute tolerance. Defaults to 1.0e-4.
* `:rtol` - The relative tolerance. Defaults to 1.0e-4.
"""
def assert_all_close(left, right, opts \\ []) do
atol = opts[:atol] || 1.0e-4
rtol = opts[:rtol] || 1.0e-4

equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)
|> Nx.to_flat_list()
|> Enum.all?(&(&1 == 1))

if !equals do
flunk("""
expected

#{inspect(left)}

to be within tolerance of

#{inspect(right)}
""")
end
end
end
39 changes: 1 addition & 38 deletions torchx/test/support/torchx_case.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,7 @@ defmodule Torchx.Case do

using do
quote do
import Torchx.Case
end
end

def assert_all_close(left, right, opts \\ []) do
atol = opts[:atol] || 1.0e-4
rtol = opts[:rtol] || 1.0e-4

equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)

if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do
flunk("""
Tensor assertion failed.
left: #{inspect(left)}
right: #{inspect(right)}
""")
end
end

def assert_equal(left, right) do
both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))

equals =
left
|> Nx.equal(right)
|> Nx.logical_or(both_nan)
|> Nx.all()
|> Nx.to_number()

if equals != 1 do
flunk("""
Tensor assertion failed.
left: #{inspect(left)}
right: #{inspect(right)}
""")
import Nx.Testing
end
end
end