@@ -294,19 +294,17 @@ defmodule Nx.Defn.GraphSplitterTest do
294294
295295 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
296296
297- [ { arg1_id , % T { shape: { 2 , 3 } , type: { :u , 8 } , data: % Expr { args: [ 0 ] } } } ] =
297+ [ { arg1_id , 0 } ] =
298298 Enum . to_list ( stage_0 . arguments )
299299
300300 assert stage_0 . argument_sources == % { arg1_id => { nil , 1 } }
301301
302302 stage_1_args =
303- Enum . sort_by ( stage_1 . arguments , fn { _id , % T { data: % Expr { op: :parameter , args: [ idx ] } } } ->
304- idx
305- end )
303+ Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
306304
307305 assert [
308- { arg_0_id , % T { shape: { } , type: { :s , 32 } } } ,
309- { arg_1_id , % T { shape: { 2 , 3 } , type: { :u , 8 } } }
306+ { arg_0_id , 0 } ,
307+ { arg_1_id , 1 }
310308 ] =
311309 stage_1_args
312310
@@ -359,21 +357,14 @@ defmodule Nx.Defn.GraphSplitterTest do
359357
360358 assert [ % Stage { } = stage_0 , % Stage { } = stage_1 ] = GraphSplitter . traverse ( expr , split_fn )
361359
362- [ { arg1_id , % T { shape: { 2 , 3 } , type: { :u , 8 } , data: % Expr { args: [ 0 ] } } } ] =
363- Enum . to_list ( stage_0 . arguments )
360+ [ { arg1_id , 0 } ] = Enum . to_list ( stage_0 . arguments )
364361
365362 assert stage_0 . argument_sources == % { arg1_id => { nil , 1 } }
366363
367364 stage_1_args =
368- Enum . sort_by ( stage_1 . arguments , fn { _id , % T { data: % Expr { op: :parameter , args: [ idx ] } } } ->
369- idx
370- end )
365+ Enum . sort_by ( stage_1 . arguments , fn { _id , idx } -> idx end )
371366
372- assert [
373- { arg_0_id , % T { shape: { } , type: { :s , 32 } } } ,
374- { arg_1_id , % T { shape: { 2 , 3 } , type: { :u , 8 } } }
375- ] =
376- stage_1_args
367+ assert [ { arg_0_id , 0 } , { arg_1_id , 1 } ] = stage_1_args
377368
378369 assert stage_1 . argument_sources == % { arg_0_id => { nil , 0 } , arg_1_id => { stage_0 . id , 0 } }
379370
0 commit comments