@@ -66,8 +66,8 @@ defmodule Nx.Defn.Graph do
6666 Enum . with_index ( args , fn arg , idx -> { { nil , idx } , arg } end )
6767 |> Map . new ( )
6868
69- { results , _scope } =
70- Enum . reduce ( chain , { nil , scope } , fn stage , { _results , scope } ->
69+ { result , _scope } =
70+ Enum . reduce ( chain , { nil , scope } , fn stage , { _result , scope } ->
7171 % { id: id , expr: expr , arguments: arguments } = stage
7272
7373 args =
@@ -80,24 +80,18 @@ defmodule Nx.Defn.Graph do
8080 { tensor , Map . put ( scope , { id , 0 } , tensor ) }
8181
8282 tuple ->
83- { _idx , scope , reverse_results } =
83+ { _idx , scope } =
8484 tuple
8585 |> Tuple . to_list ( )
86- |> Enum . reduce ( { 0 , scope , [ ] } , fn tensor , { idx , scope , results_acc } ->
87- { idx + 1 , Map . put ( scope , { id , idx } , tensor ) , [ tensor | results_acc ] }
86+ |> Enum . reduce ( { 0 , scope } , fn tensor , { idx , scope } ->
87+ { idx + 1 , Map . put ( scope , { id , idx } , tensor ) }
8888 end )
8989
90- { reverse_results , scope }
90+ { tuple , scope }
9191 end
9292 end )
9393
94- if is_list ( results ) do
95- results
96- |> Enum . reverse ( )
97- |> List . to_tuple ( )
98- else
99- results
100- end
94+ result
10195 end
10296
10397 @ doc false
0 commit comments