@@ -395,4 +395,113 @@ defmodule Nx.Defn.GraphSplitterTest do
395395 assert % T { data: % Expr { id: ^ arg_1_id , op: :parameter , args: [ 1 ] } } = a
396396 end
397397 end
398+
399+ describe "run/2" do
400+ test "executes the stages chain and returns the correct result" do
401+ function = fn arg0 , arg1 ->
402+ # root
403+ x = Nx . multiply ( arg0 , arg1 ) |> Nx.Defn.Expr . metadata ( % { split: true } )
404+
405+ # left side
406+ w_left = Nx . multiply ( x , arg1 ) |> Nx.Defn.Expr . metadata ( % { split: true } )
407+
408+ # right side
409+ w_right = Nx . divide ( x , arg1 ) |> Nx.Defn.Expr . metadata ( % { split: true } )
410+
411+ # merge
412+ Nx . add ( w_right , w_left )
413+ end
414+
415+ args = [ Nx . tensor ( [ 1 , 2 ] ) , Nx . tensor ( [ 3 , 4 ] ) ]
416+
417+ # This is used in the final assertion of this test
418+ expected_result = Nx.Defn . jit_apply ( function , args )
419+
420+ expr = apply ( Nx.Defn . debug_expr ( function ) , args )
421+
422+ split_fn = fn
423+ % T { data: % Expr { op: :metadata , args: [ _expr , % { split: true } ] } } -> true
424+ _ -> false
425+ end
426+
427+ chain = GraphSplitter . traverse ( expr , split_fn )
428+
429+ assert [ root , side1 , side2 , merge ] = chain
430+
431+ assert { % T { data: % Expr { op: :multiply , args: [ arg0 , arg1 ] } } } = root . expr
432+ assert % T { data: % Expr { op: :parameter , args: [ 0 ] } } = arg0
433+ assert % T { data: % Expr { op: :parameter , args: [ 1 ] } } = arg1
434+
435+ # because things are balanced, we don't know which of side1 and side2 are left and right
436+ # in our expr, so we should disambiguate:
437+
438+ { [ % Stage { } = left ] , [ % Stage { } = right ] } =
439+ Enum . split_with ( [ side1 , side2 ] , fn % Stage { expr: { expr } } -> expr . data . op == :multiply end )
440+
441+ # left should depend on exactly the same parameters as the root, as it's pulling from
442+ # the global scope
443+ assert { % T { data: % Expr { op: :multiply , args: [ x , arg1_left ] } } } = left . expr
444+
445+ assert % T {
446+ data: % Expr {
447+ op: :metadata ,
448+ args: [
449+ % T { data: % Expr { id: x_left_id , op: :parameter , args: [ 1 ] } } ,
450+ % { split: true }
451+ ]
452+ }
453+ } = x
454+
455+ assert % T { data: % Expr { id: arg1_left_id , op: :parameter , args: [ 0 ] } } = arg1_left
456+
457+ assert left . argument_sources [ arg1_left_id ] == { nil , 1 }
458+ assert left . argument_sources [ x_left_id ] == { root . id , 0 }
459+
460+ # right should depend on the result of the root and on arg1, but arg1 will be reindexed
461+ # we should assert that the argument source for arg1_right is correct
462+ assert { % T { data: % Expr { op: :divide , args: [ x , arg1_right ] } } } = right . expr
463+
464+ assert % T {
465+ data: % Expr {
466+ op: :metadata ,
467+ args: [
468+ % T { data: % Expr { id: x_right_id , op: :parameter , args: [ 1 ] } } ,
469+ % { split: true }
470+ ]
471+ }
472+ } = x
473+
474+ assert % T { data: % Expr { id: arg1_right_id , op: :parameter , args: [ 0 ] } } = arg1_right
475+
476+ assert right . argument_sources [ arg1_right_id ] == { nil , 1 }
477+ assert right . argument_sources [ x_right_id ] == { root . id , 0 }
478+
479+ assert % T { data: % Expr { op: :add , args: [ w_right , w_left ] } } = merge . expr
480+
481+ assert % T {
482+ data: % Expr {
483+ op: :metadata ,
484+ args: [
485+ % T { data: % Expr { id: w_right_id , op: :parameter , args: [ 0 ] } } ,
486+ % { split: true }
487+ ]
488+ }
489+ } = w_right
490+
491+ assert % T {
492+ data: % Expr {
493+ op: :metadata ,
494+ args: [
495+ % T { data: % Expr { id: w_left_id , op: :parameter , args: [ 1 ] } } ,
496+ % { split: true }
497+ ]
498+ }
499+ } = w_left
500+
501+ assert merge . argument_sources [ w_right_id ] == { right . id , 0 }
502+ assert merge . argument_sources [ w_left_id ] == { left . id , 0 }
503+
504+ assert GraphSplitter . run ( chain , args ) == expected_result
505+ end
506+ end
398507end
0 commit comments