@@ -15,9 +15,7 @@ defmodule NxIREE.Compiler do
1515 @ behaviour Nx.Defn.Compiler
1616
1717 @ impl true
18- def __compile__ ( key , vars , fun , opts ) do
19- output_container = fun . ( vars )
20-
18+ def __compile__ ( _key , vars , fun , opts ) do
2119 { iree_compiler_flags , opts } = Keyword . pop ( opts , :iree_compiler_flags , nil )
2220 { iree_runtime_options , opts } = Keyword . pop ( opts , :iree_runtime_options , [ ] )
2321 { output_mode , opts } = Keyword . pop ( opts , :output_mode , nil )
@@ -26,20 +24,22 @@ defmodule NxIREE.Compiler do
2624 raise "missing :iree_compiler_flags option"
2725 end
2826
29- mlir_module = EXLA . to_mlir_module ( key , vars , opts )
27+ % { mlir_module: mlir_module , output_container: output_container , used_inputs: used_inputs } =
28+ EXLA . to_mlir_module ( fun , vars , Keyword . put ( opts , :within_defn_compiler , true ) )
3029
3130 bytecode = NxIREE . compile ( mlir_module , iree_compiler_flags )
3231
3332 if output_mode == :bytecode do
3433 throw ( { :bytecode , % { bytecode: bytecode , output_container: output_container } } )
3534 else
3635 fn [ inputs ] ->
36+ filtered_inputs =
37+ filter_inputs_by_indices ( inputs , used_inputs )
38+
3739 { :ok , results } =
3840 NxIREE . call (
3941 bytecode ,
40- Enum . map ( inputs , fn f ->
41- f . ( )
42- end ) ,
42+ filtered_inputs ,
4343 iree_runtime_options
4444 )
4545
@@ -68,4 +68,17 @@ defmodule NxIREE.Compiler do
6868
6969 @ impl true
7070 defdelegate __to_backend__ ( opts ) , to: EXLA.Defn
71+
72+ defp filter_inputs_by_indices ( args , inputs ) do
73+ filter_by_indices_list ( args , 0 , Enum . sort ( inputs ) , fn x , _ -> x end )
74+ end
75+
76+ defp filter_by_indices_list ( [ var | vars ] , i , [ i | inputs ] , callback ) ,
77+ do: [ callback . ( var , i ) | filter_by_indices_list ( vars , i + 1 , inputs , callback ) ]
78+
79+ defp filter_by_indices_list ( [ _var | vars ] , i , inputs , callback ) ,
80+ do: filter_by_indices_list ( vars , i + 1 , inputs , callback )
81+
82+ defp filter_by_indices_list ( [ ] , _i , [ ] , _callback ) ,
83+ do: [ ]
7184end
0 commit comments