Skip to content

Commit 867d268

Browse files
committed
chore: simplify result accumulation
1 parent f7e01d1 commit 867d268

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

nx/lib/nx/defn/graph.ex

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ defmodule Nx.Defn.Graph do
6666
Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end)
6767
|> Map.new()
6868

69-
{results, _scope} =
70-
Enum.reduce(chain, {nil, scope}, fn stage, {_results, scope} ->
69+
{result, _scope} =
70+
Enum.reduce(chain, {nil, scope}, fn stage, {_result, scope} ->
7171
%{id: id, expr: expr, arguments: arguments} = stage
7272

7373
args =
@@ -80,24 +80,18 @@ defmodule Nx.Defn.Graph do
8080
{tensor, Map.put(scope, {id, 0}, tensor)}
8181

8282
tuple ->
83-
{_idx, scope, reverse_results} =
83+
{_idx, scope} =
8484
tuple
8585
|> Tuple.to_list()
86-
|> Enum.reduce({0, scope, []}, fn tensor, {idx, scope, results_acc} ->
87-
{idx + 1, Map.put(scope, {id, idx}, tensor), [tensor | results_acc]}
86+
|> Enum.reduce({0, scope}, fn tensor, {idx, scope} ->
87+
{idx + 1, Map.put(scope, {id, idx}, tensor)}
8888
end)
8989

90-
{reverse_results, scope}
90+
{tuple, scope}
9191
end
9292
end)
9393

94-
if is_list(results) do
95-
results
96-
|> Enum.reverse()
97-
|> List.to_tuple()
98-
else
99-
results
100-
end
94+
result
10195
end
10296

10397
@doc false

0 commit comments

Comments
 (0)