diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex index 08e069733e..b7c113164e 100644 --- a/nx/lib/nx/defn/graph.ex +++ b/nx/lib/nx/defn/graph.ex @@ -302,10 +302,14 @@ defmodule Nx.Defn.Graph do new_expr = put_in(expr.data.args, args) - # When we split, decide what to include in the stage and create parameter replacement {stage_expr, result_expr} = - case tensor_args do - [] -> + case {tensor_args, expr.data.op, args} do + {_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) -> + # We're effectively splitting on a tuple, so we need to create a + # stage output for each element + {wrapped_expr, new_expr} + + {[], _, _} -> # No intermediate computations - create a parameter for this split operation # The current expression will be computed in the next stage param = Expr.parameter(new_expr, map_size(state.args)) @@ -320,8 +324,30 @@ defmodule Nx.Defn.Graph do # Update state with parameter mapping if we created one state = - case tensor_args do - [] -> + case {tensor_args, expr.data.op, args} do + {_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) -> + # Register each tuple element as a stage output and create a replacement parameter + {state, _} = + wrapped_expr + |> Tuple.to_list() + |> Enum.reduce({state, 0}, fn %T{} = elem_expr, {state, index} -> + param = Expr.parameter(elem_expr, index) + + state = %{ + state + | args: + state.args + |> Map.put(elem_expr.data.id, {stage_id, index}) + |> Map.put(param.data.id, {stage_id, index}), + nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param) + } + + {state, index + 1} + end) + + state + + {[], _, _} -> # Add parameter mapping and node replacement for the split operation # Extract the parameter from the tuple param = elem(stage_expr, 0) @@ -355,9 +381,15 @@ defmodule Nx.Defn.Graph do {expr, {Map.put(cache, id, expr), state}} end - defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do - {tuple, cache} = composite_eval(tuple, state, cache) - res = elem(tuple, i) + defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}} = expr, {cache, state}) do + {tuple, {cache, state}} = composite_eval(tuple, state, cache) + + res = + case tuple do + t when is_tuple(t) -> elem(t, i) + %T{} -> put_in(expr.data.args, [tuple, i]) + end + {res, {Map.put(cache, id, res), state}} end @@ -420,6 +452,42 @@ defmodule Nx.Defn.Graph do end end + defp rewrite_subtree( + %T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr, + state, + acc + ) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], res)} + + _ -> + {tuple_expr, acc} = rewrite_subtree(tuple_expr, state, acc) + + case tuple_expr do + # Literal tuple: turn elem into a parameter for that element + t when is_tuple(t) -> + elem_expr = elem(t, index) + param = Expr.parameter(elem_expr, index) + {param, put_in(acc.used_args[elem_expr.data.id], param)} + + # Metadata-wrapped tuple: same as above + %T{data: %Expr{op: :metadata, args: [wrapped, _]}} when is_tuple(wrapped) -> + elem_expr = elem(wrapped, index) + param = Expr.parameter(elem_expr, index) + {param, put_in(acc.used_args[elem_expr.data.id], param)} + + # Tuple tensor: create a parameter pointing to this index + %T{type: {:tuple, _}} -> + param = Expr.parameter(expr, index) + {param, put_in(acc.used_args[param.data.id], param)} + + _ -> + {put_in(expr.data.args, [tuple_expr, index]), acc} + end + end + end + defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do case state.nodes_to_replace do %{^id => res} -> diff --git a/nx/test/nx/defn/graph_test.exs b/nx/test/nx/defn/graph_test.exs index a4dee7f592..54ad0977d8 100644 --- a/nx/test/nx/defn/graph_test.exs +++ b/nx/test/nx/defn/graph_test.exs @@ -354,6 +354,41 @@ defmodule Nx.Defn.GraphTest do assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left assert %T{data: %Expr{op: :parameter, args: [1]}} = a end + + test "supports splitting on tuples with metadata" do + expr = + Nx.Defn.debug_expr(fn x -> + y = Nx.add(x, 1) + z = Nx.add(x, 2) + w = {Nx.add(y, 3), Nx.add(z, 4)} + {a, b} = Nx.Defn.Expr.metadata(w, %{split: true}) + Nx.add(a, b) + end).(Nx.tensor([1, 2, 3])) + + split_fn = fn + %T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert [%{source: {nil, 0}}] = stage_0.arguments + assert {add_y, add_z} = stage_0.expr + + assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [4]}}, y]}} = + add_y + + assert %T{data: %Expr{op: :parameter, args: [0]}} = y + + assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [6]}}, ^y]}} = + add_z + + assert stage_1.arguments == [%{source: {stage_0.id, 0}}, %{source: {stage_0.id, 1}}] + assert %T{data: %Expr{op: :add, args: [add_y, add_z]}} = stage_1.expr + + assert %T{data: %Expr{op: :parameter, args: [0]}} = add_y + assert %T{data: %Expr{op: :parameter, args: [1]}} = add_z + end end describe "split/3" do