@@ -38,9 +38,7 @@ defmodule Nx.Defn.GraphSplitterTest do
3838 }
3939 ] = chain
4040
41- assert Enum . all? ( stage_0_arguments , fn { _id , % { source: { source_id , _idx } } } ->
42- source_id == nil
43- end )
41+ assert [ % { source: { nil , 0 } } , % { source: { nil , 1 } } ] == stage_0_arguments
4442
4543 assert [ { 2 , arg_2_original_node_id , arg_2_id } , { 3 , arg_3_original_node_id , arg_3_id } ] =
4644 state . nodes_to_replace
@@ -51,11 +49,7 @@ defmodule Nx.Defn.GraphSplitterTest do
5149 |> Enum . sort ( )
5250
5351 # ensure that arg2 and arg3 map to the correct stage and output container position
54- assert % {
55- arg_2_id => % { source: { stage_0_id , 0 } , index: 0 } ,
56- arg_3_id => % { source: { stage_0_id , 1 } , index: 1 }
57- } ==
58- stage_1_arguments
52+ assert [ % { source: { stage_0_id , 0 } } , % { source: { stage_0_id , 1 } } ] == stage_1_arguments
5953
6054 # ensure that arg2 and arg3 are replacing the correct nodes
6155 { _dot_node_id , % T { data: % Expr { args: [ dot_arg_0 , _ , _ , dot_arg_1 , _ , _ ] } } } =
@@ -163,9 +157,7 @@ defmodule Nx.Defn.GraphSplitterTest do
163157 }
164158 ] = chain
165159
166- assert Enum . all? ( stage_0_arguments , fn { _id , % { source: { source_id , _idx } } } ->
167- source_id == nil
168- end )
160+ assert [ % { source: { nil , 0 } } , % { source: { nil , 1 } } ] == stage_0_arguments
169161
170162 assert map_size ( state . args ) == 6
171163
@@ -199,11 +191,7 @@ defmodule Nx.Defn.GraphSplitterTest do
199191 assert arg_5_id not in original_args
200192
201193 # ensure that arg3 and arg4 map to the correct stage and output container position
202- assert % {
203- arg_3_id => % { source: { stage_0_id , 0 } , index: 0 } ,
204- arg_4_id => % { source: { stage_0_id , 1 } , index: 1 }
205- } ==
206- stage_1_arguments
194+ assert [ % { source: { stage_0_id , 0 } } , % { source: { stage_0_id , 1 } } ] == stage_1_arguments
207195
208196 # ensure that arg3 and arg4 are replacing the correct nodes
209197 { _dot_node_id , % T { data: % Expr { args: [ dot_arg_0 , _ , _ , dot_arg_1 , _ , _ ] } } } =
@@ -269,10 +257,7 @@ defmodule Nx.Defn.GraphSplitterTest do
269257 assert % T { data: % Expr { op: :sum , args: [ ^ a , [ axes: nil , keep_axes: false ] ] } } = b
270258 assert % T { data: % Expr { id: ^ arg_5_id , op: :parameter , args: [ 1 ] } } = a
271259
272- assert % {
273- arg_2_id => % { source: { nil , 2 } , index: 0 } ,
274- arg_5_id => % { source: { stage_1_id , 0 } , index: 1 }
275- } == stage_2_arguments
260+ assert [ % { source: { nil , 2 } } , % { source: { stage_1_id , 0 } } ] == stage_2_arguments
276261 end
277262
278263 test "supports optional callbacks" do
@@ -297,35 +282,17 @@ defmodule Nx.Defn.GraphSplitterTest do
297282
298283 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
299284
300- [ { arg1_id , % { source: { nil , 1 } , index: 0 } } ] =
301- Enum . to_list ( stage_0 . arguments )
302-
303- assert stage_0 . arguments == % { arg1_id => % { source: { nil , 1 } , index: 0 } }
304-
305- stage_1_args =
306- Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
307-
308- stage_0_id = stage_0 . id
309-
310- assert [
311- { arg_0_id , % { source: { nil , 0 } , index: 0 } } ,
312- { arg_1_id , % { source: { ^ stage_0_id , 0 } , index: 1 } }
313- ] =
314- stage_1_args
315-
316- assert stage_1 . arguments == % {
317- arg_0_id => % { source: { nil , 0 } , index: 0 } ,
318- arg_1_id => % { source: { stage_0 . id , 0 } , index: 1 }
319- }
285+ assert stage_0 . arguments == [ % { source: { nil , 1 } } ]
286+ assert stage_1 . arguments == [ % { source: { nil , 0 } } , % { source: { stage_0 . id , 0 } } ]
320287
321288 assert % T { data: % Expr { op: :subtract , args: [ c , d ] } } = stage_1 . expr
322289 assert % T { data: % Expr { op: :optional , args: [ call , subexpr , _fun ] } } = c
323290
324- assert % T { data: % Expr { id: ^ arg_0_id , op: :parameter , args: [ 0 ] } } = d
291+ assert % T { data: % Expr { id: arg_0_id , op: :parameter , args: [ 0 ] } } = d
325292
326293 assert % T { data: % Expr { op: :logical_not , args: [ b ] } } = call
327294 assert % T { data: % Expr { op: :sum , args: [ a , [ axes: [ 1 ] , keep_axes: false ] ] } } = b
328- assert % T { data: % Expr { id: ^ arg_1_id , op: :parameter , args: [ 1 ] } } = a
295+ assert % T { data: % Expr { id: arg_1_id , op: :parameter , args: [ 1 ] } } = a
329296
330297 assert % T {
331298 data: % Expr {
@@ -365,22 +332,9 @@ defmodule Nx.Defn.GraphSplitterTest do
365332
366333 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
367334
368- assert [ { stage_0_arg_0_id , % { source: { nil , 1 } , index: 0 } } ] = Enum . to_list ( stage_0 . arguments )
369-
370- stage_1_args =
371- Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
372-
373- stage_0_id = stage_0 . id
374-
375- assert [
376- { arg_0_id , % { source: { nil , 0 } , index: 0 } } ,
377- { arg_1_id , % { source: { ^ stage_0_id , 0 } , index: 1 } }
378- ] = stage_1_args
379-
380- assert arg_0_id != arg_1_id
381- assert arg_0_id != stage_0_arg_0_id
382- assert arg_1_id != stage_0_arg_0_id
335+ assert [ % { source: { nil , 1 } } ] == stage_0 . arguments
383336
337+ assert [ % { source: { nil , 0 } } , % { source: { stage_0 . id , 0 } } ] == stage_1 . arguments
384338 assert % T { data: % Expr { op: :subtract , args: [ c , d ] } } = stage_1 . expr
385339
386340 assert % T {
@@ -393,10 +347,10 @@ defmodule Nx.Defn.GraphSplitterTest do
393347 }
394348 } = c
395349
396- assert % T { data: % Expr { id: ^ arg_0_id , op: :parameter , args: [ 0 ] } } = d
350+ assert % T { data: % Expr { op: :parameter , args: [ 0 ] } } = d
397351
398352 assert % T { data: % Expr { op: :sum , args: [ a , [ axes: [ 1 ] , keep_axes: false ] ] } } = left
399- assert % T { data: % Expr { id: ^ arg_1_id , op: :parameter , args: [ 1 ] } } = a
353+ assert % T { data: % Expr { op: :parameter , args: [ 1 ] } } = a
400354 end
401355 end
402356
@@ -444,16 +398,16 @@ defmodule Nx.Defn.GraphSplitterTest do
444398 data: % Expr {
445399 op: :metadata ,
446400 args: [
447- % T { data: % Expr { id: x_left_id , op: :parameter , args: [ 1 ] } } ,
401+ % T { data: % Expr { op: :parameter , args: [ 1 ] } } ,
448402 % { split: true }
449403 ]
450404 }
451405 } = x
452406
453- assert % T { data: % Expr { id: arg1_left_id , op: :parameter , args: [ 0 ] } } = arg1_left
407+ assert % T { data: % Expr { op: :parameter , args: [ 0 ] } } = arg1_left
454408
455- assert left . arguments [ arg1_left_id ] . source == { nil , 1 }
456- assert left . arguments [ x_left_id ] . source == { root . id , 0 }
409+ assert Enum . fetch! ( left . arguments , 0 ) . source == { nil , 1 }
410+ assert Enum . fetch! ( left . arguments , 1 ) . source == { root . id , 0 }
457411
458412 # right should depend on the result of the root and on arg1, but arg1 will be reindexed
459413 # we should assert that the argument source for arg1_right is correct
@@ -463,24 +417,24 @@ defmodule Nx.Defn.GraphSplitterTest do
463417 data: % Expr {
464418 op: :metadata ,
465419 args: [
466- % T { data: % Expr { id: x_right_id , op: :parameter , args: [ 1 ] } } ,
420+ % T { data: % Expr { op: :parameter , args: [ 1 ] } } ,
467421 % { split: true }
468422 ]
469423 }
470424 } = x
471425
472- assert % T { data: % Expr { id: arg1_right_id , op: :parameter , args: [ 0 ] } } = arg1_right
426+ assert % T { data: % Expr { op: :parameter , args: [ 0 ] } } = arg1_right
473427
474- assert right . arguments [ arg1_right_id ] . source == { nil , 1 }
475- assert right . arguments [ x_right_id ] . source == { root . id , 0 }
428+ assert Enum . fetch! ( right . arguments , 0 ) . source == { nil , 1 }
429+ assert Enum . fetch! ( right . arguments , 1 ) . source == { root . id , 0 }
476430
477431 assert % T { data: % Expr { op: :add , args: [ w_right , w_left ] } } = merge . expr
478432
479433 assert % T {
480434 data: % Expr {
481435 op: :metadata ,
482436 args: [
483- % T { data: % Expr { id: w_right_id , op: :parameter , args: [ 0 ] } } ,
437+ % T { data: % Expr { op: :parameter , args: [ 0 ] } } ,
484438 % { split: true }
485439 ]
486440 }
@@ -490,14 +444,14 @@ defmodule Nx.Defn.GraphSplitterTest do
490444 data: % Expr {
491445 op: :metadata ,
492446 args: [
493- % T { data: % Expr { id: w_left_id , op: :parameter , args: [ 1 ] } } ,
447+ % T { data: % Expr { op: :parameter , args: [ 1 ] } } ,
494448 % { split: true }
495449 ]
496450 }
497451 } = w_left
498452
499- assert merge . arguments [ w_right_id ] . source == { right . id , 0 }
500- assert merge . arguments [ w_left_id ] . source == { left . id , 0 }
453+ assert Enum . fetch! ( merge . arguments , 0 ) . source == { right . id , 0 }
454+ assert Enum . fetch! ( merge . arguments , 1 ) . source == { left . id , 0 }
501455
502456 assert GraphSplitter . run ( chain , args ) == expected_result
503457 end
0 commit comments