Skip to content

Commit 28e1c65

Browse files
committed
refactor: only return the chain in the public interface
1 parent d4641de commit 28e1c65

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ defmodule Nx.Defn.GraphSplitter do
99
Traverses the expression and splits it into stages.
1010
"""
1111
def traverse(expr, expr_split_fn \\ fn _ -> false end) do
12+
{chain, _, _} = traverse_and_return_cache(expr, expr_split_fn)
13+
chain
14+
end
15+
16+
@doc false
17+
def traverse_and_return_cache(expr, expr_split_fn) do
1218
# expression_chain is going to be a reverse-accumulation of {category, subexpr}
1319
# that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none
1420
state = %{

nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ defmodule Nx.Defn.GraphSplitterTest do
2323
_ -> false
2424
end
2525

26-
{chain, cache, state} = GraphSplitter.traverse(expr, split_fn)
26+
{chain, cache, state} = GraphSplitter.traverse_and_return_cache(expr, split_fn)
2727

2828
assert [
2929
%Stage{
@@ -143,7 +143,7 @@ defmodule Nx.Defn.GraphSplitterTest do
143143
_ -> false
144144
end
145145

146-
{chain, cache, state} = GraphSplitter.traverse(expr, split_fn)
146+
{chain, cache, state} = GraphSplitter.traverse_and_return_cache(expr, split_fn)
147147

148148
assert [
149149
%Stage{
@@ -292,8 +292,7 @@ defmodule Nx.Defn.GraphSplitterTest do
292292
_ -> false
293293
end
294294

295-
assert {[%Stage{} = stage_0, %Stage{} = stage_1], _cache, _state} =
296-
GraphSplitter.traverse(expr, split_fn)
295+
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
297296

298297
[{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] =
299298
Enum.to_list(stage_0.arguments)
@@ -358,8 +357,7 @@ defmodule Nx.Defn.GraphSplitterTest do
358357
_ -> false
359358
end
360359

361-
assert {[%Stage{} = stage_0, %Stage{} = stage_1], _cache, _state} =
362-
GraphSplitter.traverse(expr, split_fn)
360+
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
363361

364362
[{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] =
365363
Enum.to_list(stage_0.arguments)

0 commit comments

Comments
 (0)