@@ -486,15 +486,16 @@ defmodule Axon.Compiler do
486486 name: name_fn ,
487487 opts: [ shape: _input_shape , optional: optional? ]
488488 } ,
489- _nodes ,
489+ nodes ,
490490 { cache , op_counts , block_cache , model_state_meta } ,
491491 % { mode: mode , print_values: print_values }
492492 ) do
493493 name = name_fn . ( :input , op_counts )
494494 op_counts = Map . update ( op_counts , :input , 1 , fn x -> x + 1 end )
495+ all_inputs = get_all_inputs ( nodes )
495496
496497 predict_fun = fn _params , inputs , state , _cache , result_cache , _fn_stacktrace ->
497- value = get_input ( inputs , name , optional? )
498+ value = get_input ( all_inputs , inputs , name , optional? )
498499
499500 # TODO: Add this back in
500501 # validate_input_shape!(value, shape)
@@ -509,7 +510,7 @@ defmodule Axon.Compiler do
509510 end
510511
511512 init_fun = fn template , _cache , result_cache , _fn_stacktrace , _keys ->
512- input = get_input ( template , name , optional? )
513+ input = get_input ( all_inputs , template , name , optional? )
513514 { Nx . to_template ( input ) , { % { } , result_cache } }
514515 end
515516
@@ -889,16 +890,32 @@ defmodule Axon.Compiler do
889890 { id , model_funs , cache , op_counts , block_cache , model_state_meta }
890891 end
891892
892- defp get_input ( inputs , name , optional? ) do
893+ defp get_all_inputs ( nodes ) do
894+ nodes
895+ |> Enum . filter ( fn { _ , % { op: op } } -> op == :input end )
896+ |> Enum . map ( fn { _ , % { name: name_fn } } ->
897+ # inputs require a name, so we can just ignore op counts
898+ name_fn . ( :input , % { } )
899+ end )
900+ |> Enum . uniq ( )
901+ end
902+
903+ defp get_input ( all_input_names , inputs , name , optional? ) do
893904 res =
894- case inputs do
895- % Nx.Tensor { } = inputs ->
905+ case { all_input_names , inputs } do
906+ { [ ^ name ] , % Nx.Tensor { } = inputs } ->
896907 inputs
897908
898- % { } = inputs ->
909+ { _ , % Nx.Tensor { } } ->
910+ raise ArgumentError ,
911+ "ambiguous input given to the model," <>
912+ " expected inputs with names #{ inspect ( all_input_names ) } " <>
913+ " but received a single tensor as input"
914+
915+ { _ , % { } = inputs } ->
899916 inputs [ name ]
900917
901- inputs when is_tuple ( inputs ) ->
918+ { [ ^ name ] , inputs } when is_tuple ( inputs ) ->
902919 inputs
903920
904921 _ ->
0 commit comments