@@ -486,15 +486,16 @@ defmodule Axon.Compiler do
486
486
name: name_fn ,
487
487
opts: [ shape: _input_shape , optional: optional? ]
488
488
} ,
489
- _nodes ,
489
+ nodes ,
490
490
{ cache , op_counts , block_cache , model_state_meta } ,
491
491
% { mode: mode , print_values: print_values }
492
492
) do
493
493
name = name_fn . ( :input , op_counts )
494
494
op_counts = Map . update ( op_counts , :input , 1 , fn x -> x + 1 end )
495
+ all_inputs = get_all_inputs ( nodes )
495
496
496
497
predict_fun = fn _params , inputs , state , _cache , result_cache , _fn_stacktrace ->
497
- value = get_input ( inputs , name , optional? )
498
+ value = get_input ( all_inputs , inputs , name , optional? )
498
499
499
500
# TODO: Add this back in
500
501
# validate_input_shape!(value, shape)
@@ -509,7 +510,7 @@ defmodule Axon.Compiler do
509
510
end
510
511
511
512
init_fun = fn template , _cache , result_cache , _fn_stacktrace , _keys ->
512
- input = get_input ( template , name , optional? )
513
+ input = get_input ( all_inputs , template , name , optional? )
513
514
{ Nx . to_template ( input ) , { % { } , result_cache } }
514
515
end
515
516
@@ -889,16 +890,32 @@ defmodule Axon.Compiler do
889
890
{ id , model_funs , cache , op_counts , block_cache , model_state_meta }
890
891
end
891
892
892
- defp get_input ( inputs , name , optional? ) do
893
+ defp get_all_inputs ( nodes ) do
894
+ nodes
895
+ |> Enum . filter ( fn { _ , % { op: op } } -> op == :input end )
896
+ |> Enum . map ( fn { _ , % { name: name_fn } } ->
897
+ # inputs require a name, so we can just ignore op counts
898
+ name_fn . ( :input , % { } )
899
+ end )
900
+ |> Enum . uniq ( )
901
+ end
902
+
903
+ defp get_input ( all_input_names , inputs , name , optional? ) do
893
904
res =
894
- case inputs do
895
- % Nx.Tensor { } = inputs ->
905
+ case { all_input_names , inputs } do
906
+ { [ ^ name ] , % Nx.Tensor { } = inputs } ->
896
907
inputs
897
908
898
- % { } = inputs ->
909
+ { _ , % Nx.Tensor { } } ->
910
+ raise ArgumentError ,
911
+ "ambiguous input given to the model," <>
912
+ " expected inputs with names #{ inspect ( all_input_names ) } " <>
913
+ " but received a single tensor as input"
914
+
915
+ { _ , % { } = inputs } ->
899
916
inputs [ name ]
900
917
901
- inputs when is_tuple ( inputs ) ->
918
+ { [ ^ name ] , inputs } when is_tuple ( inputs ) ->
902
919
inputs
903
920
904
921
_ ->
0 commit comments