Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 76 additions & 8 deletions nx/lib/nx/defn/graph.ex
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,14 @@ defmodule Nx.Defn.Graph do

new_expr = put_in(expr.data.args, args)

# When we split, decide what to include in the stage and create parameter replacement
{stage_expr, result_expr} =
case tensor_args do
[] ->
case {tensor_args, expr.data.op, args} do
{_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) ->
# We're effectively splitting on a tuple, so we need to create a
# stage output for each element
{wrapped_expr, new_expr}

{[], _, _} ->
# No intermediate computations - create a parameter for this split operation
# The current expression will be computed in the next stage
param = Expr.parameter(new_expr, map_size(state.args))
Expand All @@ -320,8 +324,30 @@ defmodule Nx.Defn.Graph do

# Update state with parameter mapping if we created one
state =
case tensor_args do
[] ->
case {tensor_args, expr.data.op, args} do
{_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) ->
# Register each tuple element as a stage output and create a replacement parameter
{state, _} =
wrapped_expr
|> Tuple.to_list()
|> Enum.reduce({state, 0}, fn %T{} = elem_expr, {state, index} ->
param = Expr.parameter(elem_expr, index)

state = %{
state
| args:
state.args
|> Map.put(elem_expr.data.id, {stage_id, index})
|> Map.put(param.data.id, {stage_id, index}),
nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param)
}

{state, index + 1}
end)

state

{[], _, _} ->
# Add parameter mapping and node replacement for the split operation
# Extract the parameter from the tuple
param = elem(stage_expr, 0)
Expand Down Expand Up @@ -355,9 +381,15 @@ defmodule Nx.Defn.Graph do
{expr, {Map.put(cache, id, expr), state}}
end

defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do
{tuple, cache} = composite_eval(tuple, state, cache)
res = elem(tuple, i)
defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}} = expr, {cache, state}) do
{tuple, {cache, state}} = composite_eval(tuple, state, cache)

res =
case tuple do
t when is_tuple(t) -> elem(t, i)
%T{} -> put_in(expr.data.args, [tuple, i])
end

{res, {Map.put(cache, id, res), state}}
end

Expand Down Expand Up @@ -420,6 +452,42 @@ defmodule Nx.Defn.Graph do
end
end

defp rewrite_subtree(
%T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr,
state,
acc
) do
case state.nodes_to_replace do
%{^id => res} ->
{res, put_in(acc.used_args[id], res)}

_ ->
{tuple_expr, acc} = rewrite_subtree(tuple_expr, state, acc)

case tuple_expr do
# Literal tuple: turn elem into a parameter for that element
t when is_tuple(t) ->
elem_expr = elem(t, index)
param = Expr.parameter(elem_expr, index)
{param, put_in(acc.used_args[elem_expr.data.id], param)}

# Metadata-wrapped tuple: same as above
%T{data: %Expr{op: :metadata, args: [wrapped, _]}} when is_tuple(wrapped) ->
elem_expr = elem(wrapped, index)
param = Expr.parameter(elem_expr, index)
{param, put_in(acc.used_args[elem_expr.data.id], param)}

# Tuple tensor: create a parameter pointing to this index
%T{type: {:tuple, _}} ->
param = Expr.parameter(expr, index)
{param, put_in(acc.used_args[param.data.id], param)}

_ ->
{put_in(expr.data.args, [tuple_expr, index]), acc}
end
end
end

defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
case state.nodes_to_replace do
%{^id => res} ->
Expand Down
35 changes: 35 additions & 0 deletions nx/test/nx/defn/graph_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,41 @@ defmodule Nx.Defn.GraphTest do
assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left
assert %T{data: %Expr{op: :parameter, args: [1]}} = a
end

test "supports splitting on tuples with metadata" do
expr =
Nx.Defn.debug_expr(fn x ->
y = Nx.add(x, 1)
z = Nx.add(x, 2)
w = {Nx.add(y, 3), Nx.add(z, 4)}
{a, b} = Nx.Defn.Expr.metadata(w, %{split: true})
Nx.add(a, b)
end).(Nx.tensor([1, 2, 3]))

split_fn = fn
%T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true
_ -> false
end

assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn)

assert [%{source: {nil, 0}}] = stage_0.arguments
assert {add_y, add_z} = stage_0.expr

assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [4]}}, y]}} =
add_y

assert %T{data: %Expr{op: :parameter, args: [0]}} = y

assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [6]}}, ^y]}} =
add_z

assert stage_1.arguments == [%{source: {stage_0.id, 0}}, %{source: {stage_0.id, 1}}]
assert %T{data: %Expr{op: :add, args: [add_y, add_z]}} = stage_1.expr

assert %T{data: %Expr{op: :parameter, args: [0]}} = add_y
assert %T{data: %Expr{op: :parameter, args: [1]}} = add_z
end
end

describe "split/3" do
Expand Down