@@ -21,8 +21,8 @@ defmodule Nx.Defn.GraphSplitter do
2121 Enum . with_index ( args , fn arg , idx -> { { nil , idx } , arg } end )
2222 |> Map . new ( )
2323
24- scope =
25- Enum . reduce ( chain , scope , fn stage , scope ->
24+ { results , _scope } =
25+ Enum . reduce ( chain , { nil , scope } , fn stage , { _results , scope } ->
2626 % { id: id , expr: expr , arguments: arguments } = stage
2727
2828 args =
@@ -32,31 +32,26 @@ defmodule Nx.Defn.GraphSplitter do
3232
3333 case Nx.Defn . jit_apply ( fn _ -> expr end , [ List . to_tuple ( args ) ] ) do
3434 % T { } = tensor ->
35- Map . put ( scope , { id , 0 } , tensor )
35+ { tensor , Map . put ( scope , { id , 0 } , tensor ) }
3636
3737 tuple ->
38- { _idx , scope } =
38+ { _idx , scope , reverse_results } =
3939 tuple
4040 |> Tuple . to_list ( )
41- |> Enum . reduce ( { 0 , scope } , fn tensor , { idx , scope } ->
42- { idx + 1 , Map . put ( scope , { id , idx } , tensor ) }
41+ |> Enum . reduce ( { 0 , scope , [ ] } , fn tensor , { idx , scope , results_acc } ->
42+ { idx + 1 , Map . put ( scope , { id , idx } , tensor ) , [ tensor | results_acc ] }
4343 end )
4444
45- scope
45+ { reverse_results , scope }
4646 end
4747 end )
4848
49- last_stage = List . last ( chain )
50-
51- if is_tuple ( last_stage . expr ) do
52- scope
53- |> Enum . filter ( fn { { id , _ } , _ } -> id == last_stage . id end )
54- |> Enum . sort_by ( fn { { _ , idx } , _ } -> idx end )
55- |> Enum . map ( fn { _ , tensor } -> tensor end )
49+ if is_list ( results ) do
50+ results
51+ |> Enum . reverse ( )
5652 |> List . to_tuple ( )
5753 else
58- { _ , tensor } = Enum . find ( scope , fn { { id , _ } , _ } -> id == last_stage . id end )
59- tensor
54+ results
6055 end
6156 end
6257
0 commit comments