diff --git a/exla/test/support/exla_case.ex b/exla/test/support/exla_case.ex index fa12092575..0be7869954 100644 --- a/exla/test/support/exla_case.ex +++ b/exla/test/support/exla_case.ex @@ -8,13 +8,7 @@ defmodule EXLA.Case do using do quote do import EXLA.Case - end - end - - defmacro assert_equal(left, right) do - # Assert against binary backend tensors to show diff on failure - quote do - assert unquote(left) |> to_binary_backend() == unquote(right) |> to_binary_backend() + import Nx.Testing end end @@ -22,28 +16,6 @@ defmodule EXLA.Case do Nx.backend_copy(tensor, Nx.BinaryBackend) end - def assert_all_close(left, right, opts \\ []) do - atol = opts[:atol] || 1.0e-4 - rtol = opts[:rtol] || 1.0e-4 - - equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - - if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do - flunk(""" - expected - - #{inspect(left)} - - to be within tolerance of - - #{inspect(right)} - """) - end - end - def is_mac_arm? do Application.fetch_env!(:exla, :is_mac_arm) end diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex new file mode 100644 index 0000000000..ad2d5ee482 --- /dev/null +++ b/nx/lib/nx/testing.ex @@ -0,0 +1,83 @@ +defmodule Nx.Testing do + @moduledoc """ + Testing functions for Nx tensor assertions. + + This module provides functions for asserting tensor equality and + approximate equality within specified tolerances. + """ + + import ExUnit.Assertions + import Nx, only: [is_tensor: 1] + + @doc """ + Asserts that two tensors are exactly equal. + + This handles NaN values correctly by considering NaN == NaN as true. + """ + def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do + if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def assert_equal(left, right) do + if !tensor_equal?(left, right) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def tensor_equal?(left, right) do + both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) + + left + |> Nx.equal(right) + |> Nx.logical_or(both_nan) + |> Nx.all() + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + end + + @doc """ + Asserts that two tensors are approximately equal within the given tolerances. + + See also: + + * `Nx.all_close/2` - The underlying function that performs the comparison. + + ## Options + + * `:atol` - The absolute tolerance. Defaults to 1.0e-4. + * `:rtol` - The relative tolerance. Defaults to 1.0e-4. + """ + def assert_all_close(left, right, opts \\ []) do + atol = opts[:atol] || 1.0e-4 + rtol = opts[:rtol] || 1.0e-4 + + equals = + left + |> Nx.all_close(right, atol: atol, rtol: rtol) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + + if !equals do + flunk(""" + expected + + #{inspect(left)} + + to be within tolerance of + + #{inspect(right)} + """) + end + end +end diff --git a/torchx/test/support/torchx_case.ex b/torchx/test/support/torchx_case.ex index 7056343964..3b70915159 100644 --- a/torchx/test/support/torchx_case.ex +++ b/torchx/test/support/torchx_case.ex @@ -7,44 +7,7 @@ defmodule Torchx.Case do using do quote do - import Torchx.Case - end - end - - def assert_all_close(left, right, opts \\ []) do - atol = opts[:atol] || 1.0e-4 - rtol = opts[:rtol] || 1.0e-4 - - equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - - if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do - flunk(""" - Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} - """) - end - end - - def assert_equal(left, right) do - both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) - - equals = - left - |> Nx.equal(right) - |> Nx.logical_or(both_nan) - |> Nx.all() - |> Nx.to_number() - - if equals != 1 do - flunk(""" - Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} - """) + import Nx.Testing end end end