Skip to content

Commit 384c770

Browse files
committed
feat: add GraphSplitter.run/2
1 parent e3c68d0 commit 384c770

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,58 @@ defmodule Nx.Defn.GraphSplitter do
1313
chain
1414
end
1515

16+
@doc """
17+
Executes the stage chain with the given arguments.
18+
"""
19+
def run(chain, args) do
20+
scope =
21+
Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end)
22+
|> Map.new()
23+
24+
scope =
25+
Enum.reduce(chain, scope, fn stage, scope ->
26+
%{id: id, expr: expr, argument_sources: argument_sources, arguments: arguments} = stage
27+
28+
args =
29+
arguments
30+
|> Enum.map(fn {id, %T{data: %Expr{args: [idx]}}} ->
31+
source = Map.fetch!(argument_sources, id)
32+
argument = Map.fetch!(scope, source)
33+
{idx, argument}
34+
end)
35+
|> Enum.sort_by(fn {idx, _} -> idx end)
36+
|> Enum.map(fn {_, argument} -> argument end)
37+
38+
case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do
39+
%T{} = tensor ->
40+
Map.put(scope, {id, 0}, tensor)
41+
42+
tuple ->
43+
{_idx, scope} =
44+
tuple
45+
|> Tuple.to_list()
46+
|> Enum.reduce({0, scope}, fn tensor, {idx, scope} ->
47+
{idx + 1, Map.put(scope, {id, idx}, tensor)}
48+
end)
49+
50+
scope
51+
end
52+
end)
53+
54+
last_stage = List.last(chain)
55+
56+
if is_tuple(last_stage.expr) do
57+
scope
58+
|> Enum.filter(fn {{id, _}, _} -> id == last_stage.id end)
59+
|> Enum.sort_by(fn {{_, idx}, _} -> idx end)
60+
|> Enum.map(fn {_, tensor} -> tensor end)
61+
|> List.to_tuple()
62+
else
63+
{_, tensor} = Enum.find(scope, fn {{id, _}, _} -> id == last_stage.id end)
64+
tensor
65+
end
66+
end
67+
1668
@doc false
1769
def traverse_and_return_cache(expr, expr_split_fn) do
1870
# expression_chain is going to be a reverse-accumulation of {category, subexpr}

nx/test/nx/defn/graph_splitter_test.exs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,4 +395,113 @@ defmodule Nx.Defn.GraphSplitterTest do
395395
assert %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} = a
396396
end
397397
end
398+
399+
describe "run/2" do
400+
test "executes the stages chain and returns the correct result" do
401+
function = fn arg0, arg1 ->
402+
# root
403+
x = Nx.multiply(arg0, arg1) |> Nx.Defn.Expr.metadata(%{split: true})
404+
405+
# left side
406+
w_left = Nx.multiply(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true})
407+
408+
# right side
409+
w_right = Nx.divide(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true})
410+
411+
# merge
412+
Nx.add(w_right, w_left)
413+
end
414+
415+
args = [Nx.tensor([1, 2]), Nx.tensor([3, 4])]
416+
417+
# This is used in the final assertion of this test
418+
expected_result = Nx.Defn.jit_apply(function, args)
419+
420+
expr = apply(Nx.Defn.debug_expr(function), args)
421+
422+
split_fn = fn
423+
%T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true
424+
_ -> false
425+
end
426+
427+
chain = GraphSplitter.traverse(expr, split_fn)
428+
429+
assert [root, side1, side2, merge] = chain
430+
431+
assert {%T{data: %Expr{op: :multiply, args: [arg0, arg1]}}} = root.expr
432+
assert %T{data: %Expr{op: :parameter, args: [0]}} = arg0
433+
assert %T{data: %Expr{op: :parameter, args: [1]}} = arg1
434+
435+
# because things are balanced, we don't know which of side1 and side2 are left and right
436+
# in our expr, so we should disambiguate:
437+
438+
{[%Stage{} = left], [%Stage{} = right]} =
439+
Enum.split_with([side1, side2], fn %Stage{expr: {expr}} -> expr.data.op == :multiply end)
440+
441+
# left should depend on exactly the same parameters as the root, as it's pulling from
442+
# the global scope
443+
assert {%T{data: %Expr{op: :multiply, args: [x, arg1_left]}}} = left.expr
444+
445+
assert %T{
446+
data: %Expr{
447+
op: :metadata,
448+
args: [
449+
%T{data: %Expr{id: x_left_id, op: :parameter, args: [1]}},
450+
%{split: true}
451+
]
452+
}
453+
} = x
454+
455+
assert %T{data: %Expr{id: arg1_left_id, op: :parameter, args: [0]}} = arg1_left
456+
457+
assert left.argument_sources[arg1_left_id] == {nil, 1}
458+
assert left.argument_sources[x_left_id] == {root.id, 0}
459+
460+
# right should depend on the result of the root and on arg1, but arg1 will be reindexed
461+
# we should assert that the argument source for arg1_right is correct
462+
assert {%T{data: %Expr{op: :divide, args: [x, arg1_right]}}} = right.expr
463+
464+
assert %T{
465+
data: %Expr{
466+
op: :metadata,
467+
args: [
468+
%T{data: %Expr{id: x_right_id, op: :parameter, args: [1]}},
469+
%{split: true}
470+
]
471+
}
472+
} = x
473+
474+
assert %T{data: %Expr{id: arg1_right_id, op: :parameter, args: [0]}} = arg1_right
475+
476+
assert right.argument_sources[arg1_right_id] == {nil, 1}
477+
assert right.argument_sources[x_right_id] == {root.id, 0}
478+
479+
assert %T{data: %Expr{op: :add, args: [w_right, w_left]}} = merge.expr
480+
481+
assert %T{
482+
data: %Expr{
483+
op: :metadata,
484+
args: [
485+
%T{data: %Expr{id: w_right_id, op: :parameter, args: [0]}},
486+
%{split: true}
487+
]
488+
}
489+
} = w_right
490+
491+
assert %T{
492+
data: %Expr{
493+
op: :metadata,
494+
args: [
495+
%T{data: %Expr{id: w_left_id, op: :parameter, args: [1]}},
496+
%{split: true}
497+
]
498+
}
499+
} = w_left
500+
501+
assert merge.argument_sources[w_right_id] == {right.id, 0}
502+
assert merge.argument_sources[w_left_id] == {left.id, 0}
503+
504+
assert GraphSplitter.run(chain, args) == expected_result
505+
end
506+
end
398507
end

0 commit comments

Comments
 (0)