From 9278f344b0045f5e93cb8a7391089bda4b632af8 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 5 May 2025 17:22:50 -0300 Subject: [PATCH 1/2] feat: adds testing.ex in lib --- nx/lib/nx/testing.ex | 79 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 nx/lib/nx/testing.ex diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex new file mode 100644 index 0000000000..31e8202b16 --- /dev/null +++ b/nx/lib/nx/testing.ex @@ -0,0 +1,79 @@ +defmodule Nx.Testing do + @moduledoc """ + Testing functions for Nx tensor assertions. + + This module provides functions for asserting tensor equality and + approximate equality within specified tolerances. + """ + + @doc """ + Asserts that two tensors are exactly equal. + + This handles NaN values correctly by considering NaN == NaN as true. + """ + def assert_equal(left, right) do + both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) + + equals = + left + |> Nx.equal(right) + |> Nx.logical_or(both_nan) + |> Nx.all() + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + + if !equals do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + @doc """ + Asserts that two tensors are approximately equal within the given tolerances. + + See also: + + * `Nx.all_close/2` - The underlying function that performs the comparison. + + ## Options + + * `:atol` - The absolute tolerance. Defaults to 1.0e-4. + * `:rtol` - The relative tolerance. Defaults to 1.0e-4. + """ + def assert_all_close(left, right, opts \\ []) do + atol = opts[:atol] || 1.0e-4 + rtol = opts[:rtol] || 1.0e-4 + + equals = + left + |> Nx.all_close(right, atol: atol, rtol: rtol) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + + if !equals do + flunk(""" + expected + + #{inspect(left)} + + to be within tolerance of + + #{inspect(right)} + """) + end + end + + @doc """ + Converts a tensor to the binary backend. + + This is useful for comparing tensors in assertions, as it ensures + consistent representation regardless of the original backend. + """ + def to_binary_backend(tensor) do + Nx.backend_copy(tensor, Nx.BinaryBackend) + end +end From a72980f85fcde6e5c8ff4cfa0d0d6cd7529ff5b3 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 5 May 2025 17:59:47 -0300 Subject: [PATCH 2/2] add: uses NxTesting on exla_case and torchx_case --- exla/test/support/exla_case.ex | 30 +------------------ nx/lib/nx/testing.ex | 46 ++++++++++++++++-------------- torchx/test/support/torchx_case.ex | 39 +------------------------ 3 files changed, 27 insertions(+), 88 deletions(-) diff --git a/exla/test/support/exla_case.ex b/exla/test/support/exla_case.ex index fa12092575..0be7869954 100644 --- a/exla/test/support/exla_case.ex +++ b/exla/test/support/exla_case.ex @@ -8,13 +8,7 @@ defmodule EXLA.Case do using do quote do import EXLA.Case - end - end - - defmacro assert_equal(left, right) do - # Assert against binary backend tensors to show diff on failure - quote do - assert unquote(left) |> to_binary_backend() == unquote(right) |> to_binary_backend() + import Nx.Testing end end @@ -22,28 +16,6 @@ defmodule EXLA.Case do Nx.backend_copy(tensor, Nx.BinaryBackend) end - def assert_all_close(left, right, opts \\ []) do - atol = opts[:atol] || 1.0e-4 - rtol = opts[:rtol] || 1.0e-4 - - equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - - if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do - flunk(""" - expected - - #{inspect(left)} - - to be within tolerance of - - #{inspect(right)} - """) - end - end - def is_mac_arm? do Application.fetch_env!(:exla, :is_mac_arm) end diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex index 31e8202b16..ad2d5ee482 100644 --- a/nx/lib/nx/testing.ex +++ b/nx/lib/nx/testing.ex @@ -6,23 +6,26 @@ defmodule Nx.Testing do approximate equality within specified tolerances. """ + import ExUnit.Assertions + import Nx, only: [is_tensor: 1] + @doc """ Asserts that two tensors are exactly equal. This handles NaN values correctly by considering NaN == NaN as true. """ - def assert_equal(left, right) do - both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) - - equals = - left - |> Nx.equal(right) - |> Nx.logical_or(both_nan) - |> Nx.all() - |> Nx.to_flat_list() - |> Enum.all?(&(&1 == 1)) + def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do + if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end - if !equals do + def assert_equal(left, right) do + if !tensor_equal?(left, right) do flunk(""" Tensor assertion failed. left: #{inspect(left)} @@ -31,6 +34,17 @@ defmodule Nx.Testing do end end + def tensor_equal?(left, right) do + both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) + + left + |> Nx.equal(right) + |> Nx.logical_or(both_nan) + |> Nx.all() + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + end + @doc """ Asserts that two tensors are approximately equal within the given tolerances. @@ -66,14 +80,4 @@ defmodule Nx.Testing do """) end end - - @doc """ - Converts a tensor to the binary backend. - - This is useful for comparing tensors in assertions, as it ensures - consistent representation regardless of the original backend. - """ - def to_binary_backend(tensor) do - Nx.backend_copy(tensor, Nx.BinaryBackend) - end end diff --git a/torchx/test/support/torchx_case.ex b/torchx/test/support/torchx_case.ex index 7056343964..3b70915159 100644 --- a/torchx/test/support/torchx_case.ex +++ b/torchx/test/support/torchx_case.ex @@ -7,44 +7,7 @@ defmodule Torchx.Case do using do quote do - import Torchx.Case - end - end - - def assert_all_close(left, right, opts \\ []) do - atol = opts[:atol] || 1.0e-4 - rtol = opts[:rtol] || 1.0e-4 - - equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - - if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do - flunk(""" - Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} - """) - end - end - - def assert_equal(left, right) do - both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) - - equals = - left - |> Nx.equal(right) - |> Nx.logical_or(both_nan) - |> Nx.all() - |> Nx.to_number() - - if equals != 1 do - flunk(""" - Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} - """) + import Nx.Testing end end end