@@ -29,16 +29,16 @@ defmodule Nx.Defn.GraphSplitterTest do
2929 % Stage {
3030 id: stage_0_id ,
3131 expr: stage_0_expr ,
32- argument_sources: stage_0_argument_sources
32+ arguments: stage_0_arguments
3333 } ,
3434 % Stage {
3535 id: _stage_1_id ,
3636 expr: stage_1_expr ,
37- argument_sources: stage_1_argument_sources
37+ arguments: stage_1_arguments
3838 }
3939 ] = chain
4040
41- assert Enum . all? ( stage_0_argument_sources , fn { _id , { source_id , _idx } } ->
41+ assert Enum . all? ( stage_0_arguments , fn { _id , % { source: { source_id , _idx } } } ->
4242 source_id == nil
4343 end )
4444
@@ -52,10 +52,10 @@ defmodule Nx.Defn.GraphSplitterTest do
5252
5353 # ensure that arg2 and arg3 map to the correct stage and output container position
5454 assert % {
55- arg_2_id => { stage_0_id , 0 } ,
56- arg_3_id => { stage_0_id , 1 }
55+ arg_2_id => % { source: { stage_0_id , 0 } , index: 0 } ,
56+ arg_3_id => % { source: { stage_0_id , 1 } , index: 1 }
5757 } ==
58- stage_1_argument_sources
58+ stage_1_arguments
5959
6060 # ensure that arg2 and arg3 are replacing the correct nodes
6161 { _dot_node_id , % T { data: % Expr { args: [ dot_arg_0 , _ , _ , dot_arg_1 , _ , _ ] } } } =
@@ -149,21 +149,21 @@ defmodule Nx.Defn.GraphSplitterTest do
149149 % Stage {
150150 id: stage_0_id ,
151151 expr: stage_0_expr ,
152- argument_sources: stage_0_argument_sources
152+ arguments: stage_0_arguments
153153 } ,
154154 % Stage {
155155 id: stage_1_id ,
156156 expr: stage_1_expr ,
157- argument_sources: stage_1_argument_sources
157+ arguments: stage_1_arguments
158158 } ,
159159 % Stage {
160160 id: _stage_2_id ,
161161 expr: stage_2_expr ,
162- argument_sources: stage_2_argument_sources
162+ arguments: stage_2_arguments
163163 }
164164 ] = chain
165165
166- assert Enum . all? ( stage_0_argument_sources , fn { _id , { source_id , _idx } } ->
166+ assert Enum . all? ( stage_0_arguments , fn { _id , % { source: { source_id , _idx } } } ->
167167 source_id == nil
168168 end )
169169
@@ -200,10 +200,10 @@ defmodule Nx.Defn.GraphSplitterTest do
200200
201201 # ensure that arg3 and arg4 map to the correct stage and output container position
202202 assert % {
203- arg_3_id => { stage_0_id , 0 } ,
204- arg_4_id => { stage_0_id , 1 }
203+ arg_3_id => % { source: { stage_0_id , 0 } , index: 0 } ,
204+ arg_4_id => % { source: { stage_0_id , 1 } , index: 1 }
205205 } ==
206- stage_1_argument_sources
206+ stage_1_arguments
207207
208208 # ensure that arg3 and arg4 are replacing the correct nodes
209209 { _dot_node_id , % T { data: % Expr { args: [ dot_arg_0 , _ , _ , dot_arg_1 , _ , _ ] } } } =
@@ -269,7 +269,10 @@ defmodule Nx.Defn.GraphSplitterTest do
269269 assert % T { data: % Expr { op: :sum , args: [ ^ a , [ axes: nil , keep_axes: false ] ] } } = b
270270 assert % T { data: % Expr { id: ^ arg_5_id , op: :parameter , args: [ 1 ] } } = a
271271
272- assert % { arg_2_id => { nil , 2 } , arg_5_id => { stage_1_id , 0 } } == stage_2_argument_sources
272+ assert % {
273+ arg_2_id => % { source: { nil , 2 } , index: 0 } ,
274+ arg_5_id => % { source: { stage_1_id , 0 } , index: 1 }
275+ } == stage_2_arguments
273276 end
274277
275278 test "supports optional callbacks" do
@@ -294,21 +297,26 @@ defmodule Nx.Defn.GraphSplitterTest do
294297
295298 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
296299
297- [ { arg1_id , 0 } ] =
300+ [ { arg1_id , % { source: { nil , 1 } , index: 0 } } ] =
298301 Enum . to_list ( stage_0 . arguments )
299302
300- assert stage_0 . argument_sources == % { arg1_id => { nil , 1 } }
303+ assert stage_0 . arguments == % { arg1_id => % { source: { nil , 1 } , index: 0 } }
301304
302305 stage_1_args =
303306 Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
304307
308+ stage_0_id = stage_0 . id
309+
305310 assert [
306- { arg_0_id , 0 } ,
307- { arg_1_id , 1 }
311+ { arg_0_id , % { source: { nil , 0 } , index: 0 } } ,
312+ { arg_1_id , % { source: { ^ stage_0_id , 0 } , index: 1 } }
308313 ] =
309314 stage_1_args
310315
311- assert stage_1 . argument_sources == % { arg_0_id => { nil , 0 } , arg_1_id => { stage_0 . id , 0 } }
316+ assert stage_1 . arguments == % {
317+ arg_0_id => % { source: { nil , 0 } , index: 0 } ,
318+ arg_1_id => % { source: { stage_0 . id , 0 } , index: 1 }
319+ }
312320
313321 assert % T { data: % Expr { op: :subtract , args: [ c , d ] } } = stage_1 . expr
314322 assert % T { data: % Expr { op: :optional , args: [ call , subexpr , _fun ] } } = c
@@ -357,16 +365,21 @@ defmodule Nx.Defn.GraphSplitterTest do
357365
358366 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
359367
360- [ { arg1_id , 0 } ] = Enum . to_list ( stage_0 . arguments )
361-
362- assert stage_0 . argument_sources == % { arg1_id => { nil , 1 } }
368+ assert [ { stage_0_arg_0_id , % { source: { nil , 1 } , index: 0 } } ] = Enum . to_list ( stage_0 . arguments )
363369
364370 stage_1_args =
365371 Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
366372
367- assert [ { arg_0_id , 0 } , { arg_1_id , 1 } ] = stage_1_args
373+ stage_0_id = stage_0 . id
374+
375+ assert [
376+ { arg_0_id , % { source: { nil , 0 } , index: 0 } } ,
377+ { arg_1_id , % { source: { ^ stage_0_id , 0 } , index: 1 } }
378+ ] = stage_1_args
368379
369- assert stage_1 . argument_sources == % { arg_0_id => { nil , 0 } , arg_1_id => { stage_0 . id , 0 } }
380+ assert arg_0_id != arg_1_id
381+ assert arg_0_id != stage_0_arg_0_id
382+ assert arg_1_id != stage_0_arg_0_id
370383
371384 assert % T { data: % Expr { op: :subtract , args: [ c , d ] } } = stage_1 . expr
372385
@@ -439,8 +452,8 @@ defmodule Nx.Defn.GraphSplitterTest do
439452
440453 assert % T { data: % Expr { id: arg1_left_id , op: :parameter , args: [ 0 ] } } = arg1_left
441454
442- assert left . argument_sources [ arg1_left_id ] == { nil , 1 }
443- assert left . argument_sources [ x_left_id ] == { root . id , 0 }
455+ assert left . arguments [ arg1_left_id ] . source == { nil , 1 }
456+ assert left . arguments [ x_left_id ] . source == { root . id , 0 }
444457
445458 # right should depend on the result of the root and on arg1, but arg1 will be reindexed
446459 # we should assert that the argument source for arg1_right is correct
@@ -458,8 +471,8 @@ defmodule Nx.Defn.GraphSplitterTest do
458471
459472 assert % T { data: % Expr { id: arg1_right_id , op: :parameter , args: [ 0 ] } } = arg1_right
460473
461- assert right . argument_sources [ arg1_right_id ] == { nil , 1 }
462- assert right . argument_sources [ x_right_id ] == { root . id , 0 }
474+ assert right . arguments [ arg1_right_id ] . source == { nil , 1 }
475+ assert right . arguments [ x_right_id ] . source == { root . id , 0 }
463476
464477 assert % T { data: % Expr { op: :add , args: [ w_right , w_left ] } } = merge . expr
465478
@@ -483,8 +496,8 @@ defmodule Nx.Defn.GraphSplitterTest do
483496 }
484497 } = w_left
485498
486- assert merge . argument_sources [ w_right_id ] == { right . id , 0 }
487- assert merge . argument_sources [ w_left_id ] == { left . id , 0 }
499+ assert merge . arguments [ w_right_id ] . source == { right . id , 0 }
500+ assert merge . arguments [ w_left_id ] . source == { left . id , 0 }
488501
489502 assert GraphSplitter . run ( chain , args ) == expected_result
490503 end
0 commit comments