From a104a66b500043b4627eb65ba730b19a8a797e38 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 01:10:19 -0300 Subject: [PATCH 1/3] fix: support splitting on tuples directly --- nx/lib/nx/defn/graph.ex | 151 +++++++++++++++++++++++++++++++-- nx/test/nx/defn/graph_test.exs | 35 ++++++++ 2 files changed, 179 insertions(+), 7 deletions(-) diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex index 08e069733e..4f8d88aafc 100644 --- a/nx/lib/nx/defn/graph.ex +++ b/nx/lib/nx/defn/graph.ex @@ -304,8 +304,23 @@ defmodule Nx.Defn.Graph do # When we split, decide what to include in the stage and create parameter replacement {stage_expr, result_expr} = - case tensor_args do - [] -> + case {tensor_args, expr.data.op} do + # Special case: metadata operation with empty tensor_args should include the computation + {[], :metadata} -> + # For metadata operations, we want to compute the wrapped expression in this stage + [wrapped_expr, _metadata] = args + + # If the wrapped expression is a tuple, extract its elements for the stage + case wrapped_expr do + t when is_tuple(t) -> + {t, new_expr} + + %T{} -> + # Single tensor - include it in the stage + {{wrapped_expr}, new_expr} + end + + {[], _} -> # No intermediate computations - create a parameter for this split operation # The current expression will be computed in the next stage param = Expr.parameter(new_expr, map_size(state.args)) @@ -320,8 +335,38 @@ defmodule Nx.Defn.Graph do # Update state with parameter mapping if we created one state = - case tensor_args do - [] -> + case {tensor_args, expr.data.op} do + {[], :metadata} -> + # For metadata operations with tuples, register each tuple element + [wrapped_expr, _metadata] = args + + case wrapped_expr do + t when is_tuple(t) -> + # For tuple metadata operations, output the individual elements + tuple_elements = Tuple.to_list(t) + + # Register each element with its position so they become stage arguments + state = + tuple_elements + |> Enum.with_index() + |> Enum.reduce(state, fn {elem_expr, index}, state -> + %{ + state + | args: Map.put(state.args, elem_expr.data.id, {stage_id, index}) + } + end) + + state + + %T{} -> + # Single tensor - register with position 0 + %{ + state + | args: Map.put(state.args, wrapped_expr.data.id, {stage_id, 0}) + } + end + + {[], _} -> # Add parameter mapping and node replacement for the split operation # Extract the parameter from the tuple param = elem(stage_expr, 0) @@ -355,9 +400,15 @@ defmodule Nx.Defn.Graph do {expr, {Map.put(cache, id, expr), state}} end - defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do - {tuple, cache} = composite_eval(tuple, state, cache) - res = elem(tuple, i) + defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}} = expr, {cache, state}) do + {tuple, {cache, state}} = composite_eval(tuple, state, cache) + + res = + case tuple do + t when is_tuple(t) -> elem(t, i) + %T{} -> put_in(expr.data.args, [tuple, i]) + end + {res, {Map.put(cache, id, res), state}} end @@ -420,6 +471,92 @@ defmodule Nx.Defn.Graph do end end + defp rewrite_subtree( + %T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr, + state, + acc + ) do + case state.nodes_to_replace do + %{^id => res} -> + # This elem operation is being replaced + {res, put_in(acc.used_args[id], res)} + + _ -> + # Check if this elem operation references a split tuple element + {tuple_expr, acc} = rewrite_subtree(tuple_expr, state, acc) + + # If the tuple expression is a tuple, check if its elements are in state.args + case tuple_expr do + t when is_tuple(t) -> + tuple_elements = Tuple.to_list(t) + + case Enum.at(tuple_elements, index) do + %T{data: %Expr{id: elem_id}} = elem_expr -> + # Check if this element was registered as a split tuple element + case Map.get(state.args, elem_id) do + {_stage_id, _position} -> + # This element is from a split tuple, create a parameter for it + # Will be reindexed later + param = Expr.parameter(elem_expr, 0) + {param, put_in(acc.used_args[elem_id], param)} + + _ -> + # Regular elem operation + {put_in(expr.data.args, [tuple_expr, index]), acc} + end + + _ -> + # Regular elem operation + {put_in(expr.data.args, [tuple_expr, index]), acc} + end + + %T{type: {:tuple, _}} = _tuple_tensor -> + # Check if any elements in state.args have the same stage ID (indicating a tuple split) + stage_ids = + state.args + |> Enum.map(fn {_id, {stage_id, _pos}} -> stage_id end) + |> Enum.frequencies() + + # Find stage IDs that appear multiple times (indicating tuple elements) + tuple_stage_id = + Enum.find_value(stage_ids, fn + {stage_id, count} when count > 1 and stage_id != nil -> stage_id + _ -> nil + end) + + case tuple_stage_id do + nil -> + # No tuple split detected, regular elem operation + {put_in(expr.data.args, [tuple_expr, index]), acc} + + stage_id -> + # Tuple was split, create a parameter for this element + param = Expr.parameter(expr, index) + # We need to find the element ID that corresponds to this index + elem_id = + Enum.find_value(state.args, fn + {id, {^stage_id, ^index}} -> id + _ -> nil + end) + + case elem_id do + nil -> + # Couldn't find the element, fallback to regular elem + {put_in(expr.data.args, [tuple_expr, index]), acc} + + elem_id -> + # Found the element, create parameter and add to used_args + {param, put_in(acc.used_args[elem_id], param)} + end + end + + _ -> + # Regular elem operation + {put_in(expr.data.args, [tuple_expr, index]), acc} + end + end + end + defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do case state.nodes_to_replace do %{^id => res} -> diff --git a/nx/test/nx/defn/graph_test.exs b/nx/test/nx/defn/graph_test.exs index a4dee7f592..54ad0977d8 100644 --- a/nx/test/nx/defn/graph_test.exs +++ b/nx/test/nx/defn/graph_test.exs @@ -354,6 +354,41 @@ defmodule Nx.Defn.GraphTest do assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left assert %T{data: %Expr{op: :parameter, args: [1]}} = a end + + test "supports splitting on tuples with metadata" do + expr = + Nx.Defn.debug_expr(fn x -> + y = Nx.add(x, 1) + z = Nx.add(x, 2) + w = {Nx.add(y, 3), Nx.add(z, 4)} + {a, b} = Nx.Defn.Expr.metadata(w, %{split: true}) + Nx.add(a, b) + end).(Nx.tensor([1, 2, 3])) + + split_fn = fn + %T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert [%{source: {nil, 0}}] = stage_0.arguments + assert {add_y, add_z} = stage_0.expr + + assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [4]}}, y]}} = + add_y + + assert %T{data: %Expr{op: :parameter, args: [0]}} = y + + assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [6]}}, ^y]}} = + add_z + + assert stage_1.arguments == [%{source: {stage_0.id, 0}}, %{source: {stage_0.id, 1}}] + assert %T{data: %Expr{op: :add, args: [add_y, add_z]}} = stage_1.expr + + assert %T{data: %Expr{op: :parameter, args: [0]}} = add_y + assert %T{data: %Expr{op: :parameter, args: [1]}} = add_z + end end describe "split/3" do From ac2ab9dbd278a6f4ffce90563e0c34a933fda905 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 23:30:59 -0300 Subject: [PATCH 2/3] fix --- nx/lib/nx/defn/graph.ex | 184 +++++++++++++--------------------------- 1 file changed, 59 insertions(+), 125 deletions(-) diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex index 4f8d88aafc..d1f5110033 100644 --- a/nx/lib/nx/defn/graph.ex +++ b/nx/lib/nx/defn/graph.ex @@ -302,31 +302,24 @@ defmodule Nx.Defn.Graph do new_expr = put_in(expr.data.args, args) - # When we split, decide what to include in the stage and create parameter replacement + # Special handling: if we are splitting on metadata wrapping a tuple, make + # each tuple element the stage output and register element params for later stages. + metadata_wrapped_tuple? = + match?({:metadata, [wrapped, _]} when is_tuple(wrapped), {expr.data.op, args}) + {stage_expr, result_expr} = - case {tensor_args, expr.data.op} do - # Special case: metadata operation with empty tensor_args should include the computation - {[], :metadata} -> - # For metadata operations, we want to compute the wrapped expression in this stage - [wrapped_expr, _metadata] = args - - # If the wrapped expression is a tuple, extract its elements for the stage - case wrapped_expr do - t when is_tuple(t) -> - {t, new_expr} - - %T{} -> - # Single tensor - include it in the stage - {{wrapped_expr}, new_expr} - end - - {[], _} -> + cond do + metadata_wrapped_tuple? -> + [wrapped_expr, _] = args + {wrapped_expr, new_expr} + + tensor_args == [] -> # No intermediate computations - create a parameter for this split operation # The current expression will be computed in the next stage param = Expr.parameter(new_expr, map_size(state.args)) {{param}, param} - _ -> + true -> # There are intermediate computations - only include those in the current stage # The current expression will be computed in the next stage using these outputs stage_expr = List.to_tuple(Enum.reverse(tensor_args)) @@ -335,38 +328,28 @@ defmodule Nx.Defn.Graph do # Update state with parameter mapping if we created one state = - case {tensor_args, expr.data.op} do - {[], :metadata} -> - # For metadata operations with tuples, register each tuple element - [wrapped_expr, _metadata] = args - - case wrapped_expr do - t when is_tuple(t) -> - # For tuple metadata operations, output the individual elements - tuple_elements = Tuple.to_list(t) - - # Register each element with its position so they become stage arguments - state = - tuple_elements - |> Enum.with_index() - |> Enum.reduce(state, fn {elem_expr, index}, state -> - %{ - state - | args: Map.put(state.args, elem_expr.data.id, {stage_id, index}) - } - end) - + cond do + metadata_wrapped_tuple? -> + # Register each tuple element as a stage output and create a replacement parameter + [wrapped_expr, _] = args + + wrapped_expr + |> Tuple.to_list() + |> Enum.with_index() + |> Enum.reduce(state, fn {%T{} = elem_expr, index}, state -> + param = Expr.parameter(elem_expr, index) + + %{ state + | args: + state.args + |> Map.put(elem_expr.data.id, {stage_id, index}) + |> Map.put(param.data.id, {stage_id, index}), + nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param) + } + end) - %T{} -> - # Single tensor - register with position 0 - %{ - state - | args: Map.put(state.args, wrapped_expr.data.id, {stage_id, 0}) - } - end - - {[], _} -> + tensor_args == [] -> # Add parameter mapping and node replacement for the split operation # Extract the parameter from the tuple param = elem(stage_expr, 0) @@ -377,7 +360,7 @@ defmodule Nx.Defn.Graph do nodes_to_replace: Map.put(state.nodes_to_replace, new_expr.data.id, param) } - _ -> + true -> state end @@ -471,6 +454,19 @@ defmodule Nx.Defn.Graph do end end + defp rewrite_subtree(%T{data: %Expr{id: id, args: args, op: op}} = expr, state, acc) + when op != :elem do + case state.nodes_to_replace do + %{^id => res} -> + # nodes_to_replace always contains a param + {res, put_in(acc.used_args[id], res)} + + _ -> + {args, acc} = composite_rewrite_subtree(args, state, acc) + {put_in(expr.data.args, args), acc} + end + end + defp rewrite_subtree( %T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr, state, @@ -478,96 +474,34 @@ defmodule Nx.Defn.Graph do ) do case state.nodes_to_replace do %{^id => res} -> - # This elem operation is being replaced {res, put_in(acc.used_args[id], res)} _ -> - # Check if this elem operation references a split tuple element {tuple_expr, acc} = rewrite_subtree(tuple_expr, state, acc) - # If the tuple expression is a tuple, check if its elements are in state.args case tuple_expr do + # Literal tuple: turn elem into a parameter for that element t when is_tuple(t) -> - tuple_elements = Tuple.to_list(t) - - case Enum.at(tuple_elements, index) do - %T{data: %Expr{id: elem_id}} = elem_expr -> - # Check if this element was registered as a split tuple element - case Map.get(state.args, elem_id) do - {_stage_id, _position} -> - # This element is from a split tuple, create a parameter for it - # Will be reindexed later - param = Expr.parameter(elem_expr, 0) - {param, put_in(acc.used_args[elem_id], param)} - - _ -> - # Regular elem operation - {put_in(expr.data.args, [tuple_expr, index]), acc} - end + elem_expr = elem(t, index) + param = Expr.parameter(elem_expr, index) + {param, put_in(acc.used_args[elem_expr.data.id], param)} - _ -> - # Regular elem operation - {put_in(expr.data.args, [tuple_expr, index]), acc} - end + # Metadata-wrapped tuple: same as above + %T{data: %Expr{op: :metadata, args: [wrapped, _]}} when is_tuple(wrapped) -> + elem_expr = elem(wrapped, index) + param = Expr.parameter(elem_expr, index) + {param, put_in(acc.used_args[elem_expr.data.id], param)} - %T{type: {:tuple, _}} = _tuple_tensor -> - # Check if any elements in state.args have the same stage ID (indicating a tuple split) - stage_ids = - state.args - |> Enum.map(fn {_id, {stage_id, _pos}} -> stage_id end) - |> Enum.frequencies() - - # Find stage IDs that appear multiple times (indicating tuple elements) - tuple_stage_id = - Enum.find_value(stage_ids, fn - {stage_id, count} when count > 1 and stage_id != nil -> stage_id - _ -> nil - end) - - case tuple_stage_id do - nil -> - # No tuple split detected, regular elem operation - {put_in(expr.data.args, [tuple_expr, index]), acc} - - stage_id -> - # Tuple was split, create a parameter for this element - param = Expr.parameter(expr, index) - # We need to find the element ID that corresponds to this index - elem_id = - Enum.find_value(state.args, fn - {id, {^stage_id, ^index}} -> id - _ -> nil - end) - - case elem_id do - nil -> - # Couldn't find the element, fallback to regular elem - {put_in(expr.data.args, [tuple_expr, index]), acc} - - elem_id -> - # Found the element, create parameter and add to used_args - {param, put_in(acc.used_args[elem_id], param)} - end - end + # Tuple tensor: create a parameter pointing to this index + %T{type: {:tuple, _}} -> + param = Expr.parameter(expr, index) + {param, put_in(acc.used_args[param.data.id], param)} _ -> - # Regular elem operation {put_in(expr.data.args, [tuple_expr, index]), acc} end end end - defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do - case state.nodes_to_replace do - %{^id => res} -> - # nodes_to_replace always contains a param - {res, put_in(acc.used_args[id], res)} - - _ -> - {args, acc} = composite_rewrite_subtree(args, state, acc) - {put_in(expr.data.args, args), acc} - end - end - defp rewrite_subtree(other, _, acc), do: {other, acc} end From 114fe4e1796e04be77a67457e8b19f8deb98899b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 23:39:54 -0300 Subject: [PATCH 3/3] review --- nx/lib/nx/defn/graph.ex | 81 ++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex index d1f5110033..b7c113164e 100644 --- a/nx/lib/nx/defn/graph.ex +++ b/nx/lib/nx/defn/graph.ex @@ -302,24 +302,20 @@ defmodule Nx.Defn.Graph do new_expr = put_in(expr.data.args, args) - # Special handling: if we are splitting on metadata wrapping a tuple, make - # each tuple element the stage output and register element params for later stages. - metadata_wrapped_tuple? = - match?({:metadata, [wrapped, _]} when is_tuple(wrapped), {expr.data.op, args}) - {stage_expr, result_expr} = - cond do - metadata_wrapped_tuple? -> - [wrapped_expr, _] = args + case {tensor_args, expr.data.op, args} do + {_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) -> + # We're effectively splitting on a tuple, so we need to create a + # stage output for each element {wrapped_expr, new_expr} - tensor_args == [] -> + {[], _, _} -> # No intermediate computations - create a parameter for this split operation # The current expression will be computed in the next stage param = Expr.parameter(new_expr, map_size(state.args)) {{param}, param} - true -> + _ -> # There are intermediate computations - only include those in the current stage # The current expression will be computed in the next stage using these outputs stage_expr = List.to_tuple(Enum.reverse(tensor_args)) @@ -328,28 +324,30 @@ defmodule Nx.Defn.Graph do # Update state with parameter mapping if we created one state = - cond do - metadata_wrapped_tuple? -> + case {tensor_args, expr.data.op, args} do + {_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) -> # Register each tuple element as a stage output and create a replacement parameter - [wrapped_expr, _] = args + {state, _} = + wrapped_expr + |> Tuple.to_list() + |> Enum.reduce({state, 0}, fn %T{} = elem_expr, {state, index} -> + param = Expr.parameter(elem_expr, index) - wrapped_expr - |> Tuple.to_list() - |> Enum.with_index() - |> Enum.reduce(state, fn {%T{} = elem_expr, index}, state -> - param = Expr.parameter(elem_expr, index) + state = %{ + state + | args: + state.args + |> Map.put(elem_expr.data.id, {stage_id, index}) + |> Map.put(param.data.id, {stage_id, index}), + nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param) + } + + {state, index + 1} + end) - %{ - state - | args: - state.args - |> Map.put(elem_expr.data.id, {stage_id, index}) - |> Map.put(param.data.id, {stage_id, index}), - nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param) - } - end) + state - tensor_args == [] -> + {[], _, _} -> # Add parameter mapping and node replacement for the split operation # Extract the parameter from the tuple param = elem(stage_expr, 0) @@ -360,7 +358,7 @@ defmodule Nx.Defn.Graph do nodes_to_replace: Map.put(state.nodes_to_replace, new_expr.data.id, param) } - true -> + _ -> state end @@ -454,19 +452,6 @@ defmodule Nx.Defn.Graph do end end - defp rewrite_subtree(%T{data: %Expr{id: id, args: args, op: op}} = expr, state, acc) - when op != :elem do - case state.nodes_to_replace do - %{^id => res} -> - # nodes_to_replace always contains a param - {res, put_in(acc.used_args[id], res)} - - _ -> - {args, acc} = composite_rewrite_subtree(args, state, acc) - {put_in(expr.data.args, args), acc} - end - end - defp rewrite_subtree( %T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr, state, @@ -503,5 +488,17 @@ defmodule Nx.Defn.Graph do end end + defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do + case state.nodes_to_replace do + %{^id => res} -> + # nodes_to_replace always contains a param + {res, put_in(acc.used_args[id], res)} + + _ -> + {args, acc} = composite_rewrite_subtree(args, state, acc) + {put_in(expr.data.args, args), acc} + end + end + defp rewrite_subtree(other, _, acc), do: {other, acc} end