Skip to content

Commit d63bba1

Browse files
committed
refactor: return final results from Enum.reduce directly
1 parent b8dd06a commit d63bba1

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ defmodule Nx.Defn.GraphSplitter do
2121
Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end)
2222
|> Map.new()
2323

24-
scope =
25-
Enum.reduce(chain, scope, fn stage, scope ->
24+
{results, _scope} =
25+
Enum.reduce(chain, {nil, scope}, fn stage, {_results, scope} ->
2626
%{id: id, expr: expr, arguments: arguments} = stage
2727

2828
args =
@@ -32,31 +32,26 @@ defmodule Nx.Defn.GraphSplitter do
3232

3333
case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do
3434
%T{} = tensor ->
35-
Map.put(scope, {id, 0}, tensor)
35+
{tensor, Map.put(scope, {id, 0}, tensor)}
3636

3737
tuple ->
38-
{_idx, scope} =
38+
{_idx, scope, reverse_results} =
3939
tuple
4040
|> Tuple.to_list()
41-
|> Enum.reduce({0, scope}, fn tensor, {idx, scope} ->
42-
{idx + 1, Map.put(scope, {id, idx}, tensor)}
41+
|> Enum.reduce({0, scope, []}, fn tensor, {idx, scope, results_acc} ->
42+
{idx + 1, Map.put(scope, {id, idx}, tensor), [tensor | results_acc]}
4343
end)
4444

45-
scope
45+
{reverse_results, scope}
4646
end
4747
end)
4848

49-
last_stage = List.last(chain)
50-
51-
if is_tuple(last_stage.expr) do
52-
scope
53-
|> Enum.filter(fn {{id, _}, _} -> id == last_stage.id end)
54-
|> Enum.sort_by(fn {{_, idx}, _} -> idx end)
55-
|> Enum.map(fn {_, tensor} -> tensor end)
49+
if is_list(results) do
50+
results
51+
|> Enum.reverse()
5652
|> List.to_tuple()
5753
else
58-
{_, tensor} = Enum.find(scope, fn {{id, _}, _} -> id == last_stage.id end)
59-
tensor
54+
results
6055
end
6156
end
6257

0 commit comments

Comments
 (0)