255 changes: 255 additions & 0 deletions nx/guides/advanced/backend_comparison.livemd
@@ -0,0 +1,255 @@
# Backend Comparison with Evaluator

```elixir
Mix.install([
  # {:nx, "~> 0.7"}
  {:nx, path: Path.join(__DIR__, "../..")},
  {:mimic, "~> 1.7"}
])
```

## Introduction

This guide demonstrates how to use `Nx.Defn.Evaluator` to compare the outputs of different backends. This is particularly useful for:

* **Testing backend implementations** - Ensure different backends produce consistent results
* **Debugging numerical differences** - Identify where backends diverge
* **Validating optimizations** - Confirm that optimized backends match reference implementations

The evaluator's `debug_options` feature saves each node's computation as an executable `.exs` file, making it easy to reconstruct and compare tensors across backends.

## How It Works

When you enable `debug_options` with a `save_path`, the evaluator:

1. Saves each computation node as a separate `.exs` file
2. Serializes tensors as executable `Nx.from_binary()` calls
3. Preserves backend information, shape, type, and names
4. Creates files that can be directly executed to reconstruct tensors

This allows you to:

* Run the same computation with different backends
* Compare corresponding node outputs (this guide uses `Nx.all_close/3`)
* Identify exactly where backends differ
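
The serialization idea can be tried by hand. The following is a minimal sketch, assuming Nx is available from the setup cell; the variable names are illustrative. It flattens a tensor to a binary, then rebuilds it with `Nx.from_binary/3` and `Nx.reshape/2`, which is the same shape of code the generated files contain.

```elixir
original = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])

# Raw bytes plus the metadata needed to rebuild the tensor
binary = Nx.to_binary(original)
type = Nx.type(original)
shape = Nx.shape(original)

# from_binary returns a flat tensor, so the shape is restored separately
rebuilt =
  binary
  |> Nx.from_binary(type, backend: {Nx.BinaryBackend, []})
  |> Nx.reshape(shape)

IO.inspect(Nx.to_number(Nx.all_close(original, rebuilt)) == 1, label: "round trip ok?")
```

Because the binary carries no shape or names, that metadata has to be written next to it; this is why the generated code chains a reshape (and, when needed, a rename) after `Nx.from_binary/3`.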

## Simulating Backend Differences with Mimic

Instead of shipping a dedicated mock backend, we can use `Mimic.stub/3` to override individual callbacks on `Nx.BinaryBackend`. First we initialize `Mimic` and copy the binary backend so it can be stubbed safely. Then `add`, `multiply`, and `divide` are replaced with other operations so that the divergence is easy to spot.

```elixir

Mimic.copy(Nx.BinaryBackend)

defmodule BackendSwaps do
  def enable! do
    Mimic.stub(Nx.BinaryBackend, :add, fn out, left, right ->
      Nx.BinaryBackend.subtract(out, left, right)
    end)

    Mimic.stub(Nx.BinaryBackend, :multiply, fn out, left, right ->
      Nx.BinaryBackend.add(out, left, right)
    end)

    Mimic.stub(Nx.BinaryBackend, :divide, fn out, left, right ->
      Nx.BinaryBackend.add(out, left, right)
    end)
  end

  def restore! do
    for fun <- [:add, :multiply, :divide] do
      Mimic.stub(Nx.BinaryBackend, fun, fn out, left, right ->
        Mimic.call_original(Nx.BinaryBackend, fun, [out, left, right])
      end)
    end
  end
end
```

## Example: Simple Computation

Let's define a simple computation to compare across backends:

```elixir
defmodule SimpleComputation do
  import Nx.Defn

  defn compute(x, y) do
    a = Nx.add(x, y)
    b = Nx.multiply(a, 2)
    Nx.divide(b, 3)
  end
end
```

### Prepare Test Data

```elixir
# Create some test input
x = Nx.tensor([1.0, 2.0, 3.0, 4.0])
y = Nx.tensor([0.5, 1.5, 2.5, 3.5])

IO.puts("Input tensors:")
IO.inspect(x, label: "x")
IO.inspect(y, label: "y")
```
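
As a reference for the unmodified run, the computation `(x + y) * 2 / 3` can be worked out with plain Elixir lists. This is only a sketch that does not use Nx; the names `xs`, `ys`, and `reference` are illustrative. Nx stores these inputs as 32-bit floats, so tensor results should be compared against these values with a tolerance rather than exactly.

```elixir
xs = [1.0, 2.0, 3.0, 4.0]
ys = [0.5, 1.5, 2.5, 3.5]

# Element-wise (x + y) * 2 / 3, mirroring SimpleComputation.compute/2
reference = Enum.zip_with(xs, ys, fn x, y -> (x + y) * 2 / 3 end)

IO.inspect(reference, label: "reference")
```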

## Preparing the Function for Comparison

To compare nodes between runs, every node in the graph must keep the same `id` on both backends. Each trace of the function would otherwise generate fresh ids, so we use `Nx.Defn.debug_expr/1` to build the expression for `SimpleComputation.compute/2` once, up front.

This trick guarantees that both `Nx.Defn.jit/2` calls receive the same expression. It is meant for debugging and should be used sparingly.

```elixir
expr = Nx.Defn.debug_expr(&SimpleComputation.compute/2).(x, y)

precompiled = fn _x, _y -> expr end
```
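
To see why a single trace is reused, the sketch below traces a small anonymous function twice and compares the ids of the two root nodes. It assumes Nx from the setup cell, and that the root node's id is reachable at `.data.id` (as the `%Expr{id: id}` pattern in the evaluator suggests); `fun`, `args`, `first`, and `second` are illustrative names.

```elixir
fun = fn a, b -> Nx.add(a, b) end
args = [Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0])]

# Each call to the function returned by debug_expr traces `fun` from scratch
first = apply(Nx.Defn.debug_expr(fun), args)
second = apply(Nx.Defn.debug_expr(fun), args)

# If ids are regenerated per trace, the two root ids differ, and node files
# written from two separate traces would not share names
IO.inspect(first.data.id == second.data.id, label: "same root id?")
```

Since the debug files are named after node ids, reusing one expression is what keeps the file names aligned across the two output directories.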

## Running with Backend A

Let's run our computation with the first backend (`Nx.BinaryBackend` in this example, but it could be any backend):

```elixir
# Clean up and create output directory
File.rm_rf!("/tmp/backend_a")
File.mkdir_p!("/tmp/backend_a")

# Run computation with debug output enabled
result_a =
  Nx.Defn.jit(
    precompiled,
    compiler: Nx.Defn.Evaluator,
    debug_options: [save_path: "/tmp/backend_a"]
  ).(x, y)

IO.puts("\n✅ Backend A completed")
IO.inspect(result_a, label: "Result A")
IO.puts("Backend: #{inspect(result_a.data.__struct__)}")

# Show what files were generated
files_a = File.ls!("/tmp/backend_a")
IO.puts("\nGenerated #{length(files_a)} node files:")
Enum.each(files_a, &IO.puts(" - #{&1}"))
```

## Examining the Output Files

Let's look at what the `.exs` files contain:

```elixir
# Read and display one of the generated files
# (sorted first, because File.ls!/1 does not guarantee any order)
example_file = File.ls!("/tmp/backend_a") |> Enum.sort() |> List.last()
content = File.read!(Path.join("/tmp/backend_a", example_file))

IO.puts("=== Content of #{example_file} ===\n")
IO.puts(content)
```

Notice the format:

* **Node ID** - Unique identifier for this computation node
* **Operation** - The operation being performed (e.g., `:add`, `:multiply`, `:parameter`)
* **Arguments** - The node's arguments (parameters and tensors), each stored as an inspected string
* **Result** - Executable code that reconstructs the output tensor from binary

### Verifying Executability

Each `.exs` file is a self-contained Elixir script, so you can execute it directly:

```elixir
example_path = Path.join("/tmp/backend_a", example_file)
Code.eval_file(example_path)

```
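
`Code.eval_file/1` returns a `{value, bindings}` tuple: the value of the last expression in the script, plus every variable the script bound. The sketch below illustrates this with a hand-written stand-in file that uses plain values instead of tensors; the file name and its contents are made up for illustration and only mimic the layout of a generated node file.

```elixir
path = Path.join(System.tmp_dir!(), "node_example.exs")

# A stand-in with the same variable layout as a generated node file
File.write!(path, """
node_id = "0.1.2.3"
operation = ":add"
args = ["x", "y"]
result = [1.5, 3.5]
""")

# value is the last expression of the script; bindings is a keyword list
{value, bindings} = Code.eval_file(path)
operation = Keyword.fetch!(bindings, :operation)

IO.inspect(value, label: "value")
IO.inspect(operation, label: "operation")

File.rm!(path)
```

The comparison code later in this guide relies on both halves of the tuple: the value for the tensor and the bindings for the operation name.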

## Running with Backend B

Now let's run the same computation with the swapped operations. We leave `Nx` on its default backend, but temporarily enable the Mimic stubs so the evaluator will capture the modified behaviour.

```elixir
# Clean up and create output directory for backend B
File.rm_rf!("/tmp/backend_b")
File.mkdir_p!("/tmp/backend_b")

BackendSwaps.enable!()

result_b =
  Nx.Defn.jit(
    precompiled,
    compiler: Nx.Defn.Evaluator,
    debug_options: [save_path: "/tmp/backend_b"]
  ).(x, y)

IO.puts("✅ Backend B completed")
IO.inspect(result_b, label: "Result B")
IO.puts("Backend: #{inspect(result_b.data.__struct__)}")

files_b = File.ls!("/tmp/backend_b")
IO.puts("\nGenerated #{length(files_b)} node files")

BackendSwaps.restore!()
```

## Comparing the Outputs

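Now we evaluate the generated `.exs` files, compare the nodes pairwise, and summarise the matches and mismatches. Because both runs reused the same expression, the node ids, and therefore the sorted file names, line up between the two directories.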

```elixir
IO.puts("Comparing outputs from .exs files")
IO.puts(String.duplicate("-", 60))

files_a = File.ls!("/tmp/backend_a") |> Enum.sort()
files_b = File.ls!("/tmp/backend_b") |> Enum.sort()

IO.puts("Backend A generated #{length(files_a)} files")
IO.puts("Backend B generated #{length(files_b)} files")

comparison =
  Enum.zip_with(files_a, files_b, fn file_a, file_b ->
    {tensor_a, bindings_a} = Code.eval_file(Path.join("/tmp/backend_a", file_a))
    {tensor_b, _bindings_b} = Code.eval_file(Path.join("/tmp/backend_b", file_b))

    op = Keyword.get(bindings_a, :operation)
    match? = Nx.all_close(tensor_a, tensor_b, atol: 1.0e-6) |> Nx.to_number() == 1

    %{
      operation: op,
      tensor_a: tensor_a,
      tensor_b: tensor_b,
      match?: match?,
      file_a: file_a,
      file_b: file_b
    }
  end)

{matches, mismatches} = Enum.split_with(comparison, & &1.match?)

IO.puts("\n Summary:")
IO.puts(String.duplicate("-", 60))

if Enum.any?(matches) do
  IO.puts("✅ Matched nodes (#{length(matches)}):")

  Enum.each(matches, fn match ->
    IO.puts("  - #{match.operation} (#{match.file_a})")
  end)
else
  IO.puts("\n❌ No nodes match!")
end

if Enum.any?(mismatches) do
  IO.puts("\n❌ Mismatched nodes (#{length(mismatches)}):")

  Enum.each(mismatches, fn mismatch ->
    IO.puts("  - #{mismatch.operation} (#{mismatch.file_a})")
    IO.puts("Backend A")
    IO.inspect(mismatch.tensor_a)
    IO.puts("Backend B")
    IO.inspect(mismatch.tensor_b)
  end)
else
  IO.puts("\n✅ All nodes match!")
end
```

With the Mimic stubs in place, the evaluator's debug artifacts show exactly where the divergence starts, which makes it straightforward to pinpoint inconsistent nodes between implementations. The summary lists both the matching and the mismatching nodes.
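
When a node mismatches, it can help to quantify how far apart the two tensors are. The following is a minimal sketch, not part of the comparison code above: it uses two hand-written tensors as stand-ins for a pair loaded from the `.exs` files and assumes Nx from the setup cell.

```elixir
a = Nx.tensor([1.0, 2.0, 3.0])
b = Nx.tensor([1.0, 2.5, 3.0])

# Largest element-wise absolute difference between the two tensors
max_abs_diff =
  a
  |> Nx.subtract(b)
  |> Nx.abs()
  |> Nx.reduce_max()
  |> Nx.to_number()

IO.inspect(max_abs_diff, label: "max |a - b|")
```

A value just above the chosen `atol` points to rounding differences between backends, while a large value points to a genuinely different operation.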
113 changes: 91 additions & 22 deletions nx/lib/nx/defn/evaluator.ex
@@ -535,36 +535,105 @@ defmodule Nx.Defn.Evaluator do
   end

   defp format_node_info(%Expr{id: id, op: op, args: args}, res, inspect_limit) do
-    args =
-      Enum.map(
-        args,
-        &inspect(&1, custom_options: [print_id: true], limit: inspect_limit)
-      )
+    id_str = :erlang.ref_to_list(id) |> List.to_string() |> String.replace(["#Ref<", ">"], "")

-    result_str = inspect(res, limit: inspect_limit)
+    inspect_opts =
+      case inspect_limit do
+        nil -> []
+        limit -> [limit: limit]
+      end

+    args_code =
+      args
+      |> Enum.map(fn arg ->
+        inspected =
+          arg
+          |> inspect(inspect_opts)
+          |> String.trim()
+
+        " #{inspect(inspected)}"
+      end)
+      |> Enum.join(",\n")
+
+    # Format result as serialized tensor
+    result_code = "result = #{serialize_tensor(res)}"
+
+    """
+    node_id = "#{id_str}"
+    operation = "#{inspect(op)}"
+
-    import Inspect.Algebra
+    args = [
+    #{args_code}
+    ]

+    # Result:
+    #{result_code}
+    """
+  end
+
+  defp serialize_tensor(%Nx.Tensor{data: %Expr{id: id}} = _tensor) when is_reference(id) do
+    # This is an unevaluated expression, not a concrete tensor
+    # Show the Node ID so users can find which file contains this tensor
+    id_str = :erlang.ref_to_list(id) |> List.to_string() |> String.replace(["#Ref<", ">"], "")
+    "# See Node ID: #{id_str}"
+  end
+
+  defp serialize_tensor(%Nx.Tensor{data: %Expr{}} = _tensor) do
+    # Expression without a valid reference ID
+    "# <unevaluated expression>"
+  end
+
-    id = :erlang.ref_to_list(id) |> List.to_string() |> String.replace(["#Ref<", ">"], "")
+  defp serialize_tensor(%Nx.Tensor{} = tensor) do
+    # Get the backend information from the tensor's data
+    {backend, backend_opts} =
+      case tensor.data do
+        %backend_mod{} -> {backend_mod, []}
+        _ -> Nx.default_backend()
+      end
+
+    # Convert tensor to binary and get metadata
+    binary = Nx.to_binary(tensor)
+    type = tensor.type
+    shape = tensor.shape
+    names = tensor.names
+
+    # Format the binary as a binary literal
+    binary_str = inspect(binary, limit: :infinity)
+
+    # Build the executable Nx code
+    backend_str = "{#{inspect(backend)}, #{inspect(backend_opts)}}"
+
+    code = "Nx.from_binary(#{binary_str}, #{inspect(type)}, backend: #{backend_str})"
+
+    # Add reshape if needed (non-scalar)
+    code =
+      if shape != {} do
+        shape_str = inspect(shape)
+        code <> " |> Nx.reshape(#{shape_str})"
+      else
+        code
+      end
+
+    # Add rename if there are non-nil names
+    code =
+      if names != [] and Enum.any?(names, &(&1 != nil)) do
+        names_list = inspect(names)
+        code <> " |> Nx.rename(#{names_list})"
+      else
+        code
+      end
+
+    code
+  end

-    concat([
-      "Node ID: #{id}",
-      line(),
-      "Operation: #{inspect(op)}",
-      line(),
-      concat(Enum.intersperse(["Args:" | args], line())),
-      line(),
-      "Result:",
-      line(),
-      result_str
-    ])
-    |> format(100)
-    |> IO.iodata_to_binary()
+  defp serialize_tensor(other) do
+    # For non-tensor values (numbers, tuples, etc.)
+    inspect(other)
   end

   defp save_node_info_to_file(save_path, id, node_info) do
     sanitized_id = inspect(id) |> String.replace(~r/[^a-zA-Z0-9_]/, "_")
-    file = Path.join(save_path, "node_#{sanitized_id}.txt")
+    file = Path.join(save_path, "node_#{sanitized_id}.exs")
     File.write!(file, node_info)
   end
 end
2 changes: 2 additions & 0 deletions nx/mix.exs
@@ -68,6 +68,8 @@ defmodule Nx.MixProject do
"guides/advanced/vectorization.livemd",
"guides/advanced/aggregation.livemd",
"guides/advanced/automatic_differentiation.livemd",
"guides/advanced/backend_comparison.livemd",
"guides/advanced/complex_fft.livemd",
"guides/exercises/exercises-1-20.livemd"
],
skip_undefined_reference_warnings_on: ["CHANGELOG.md"],