Skip to content

Commit cc7dec6

Browse files
authored
Raise on ambiguous inputs (#599)
1 parent ce2e247 commit cc7dec6

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

lib/axon/compiler.ex

+25-8
Original file line numberDiff line numberDiff line change
@@ -486,15 +486,16 @@ defmodule Axon.Compiler do
486486
name: name_fn,
487487
opts: [shape: _input_shape, optional: optional?]
488488
},
489-
_nodes,
489+
nodes,
490490
{cache, op_counts, block_cache, model_state_meta},
491491
%{mode: mode, print_values: print_values}
492492
) do
493493
name = name_fn.(:input, op_counts)
494494
op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end)
495+
all_inputs = get_all_inputs(nodes)
495496

496497
predict_fun = fn _params, inputs, state, _cache, result_cache, _fn_stacktrace ->
497-
value = get_input(inputs, name, optional?)
498+
value = get_input(all_inputs, inputs, name, optional?)
498499

499500
# TODO: Add this back in
500501
# validate_input_shape!(value, shape)
@@ -509,7 +510,7 @@ defmodule Axon.Compiler do
509510
end
510511

511512
init_fun = fn template, _cache, result_cache, _fn_stacktrace, _keys ->
512-
input = get_input(template, name, optional?)
513+
input = get_input(all_inputs, template, name, optional?)
513514
{Nx.to_template(input), {%{}, result_cache}}
514515
end
515516

@@ -889,16 +890,32 @@ defmodule Axon.Compiler do
889890
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
890891
end
891892

892-
defp get_input(inputs, name, optional?) do
893+
defp get_all_inputs(nodes) do
894+
nodes
895+
|> Enum.filter(fn {_, %{op: op}} -> op == :input end)
896+
|> Enum.map(fn {_, %{name: name_fn}} ->
897+
# inputs require a name, so we can just ignore op counts
898+
name_fn.(:input, %{})
899+
end)
900+
|> Enum.uniq()
901+
end
902+
903+
defp get_input(all_input_names, inputs, name, optional?) do
893904
res =
894-
case inputs do
895-
%Nx.Tensor{} = inputs ->
905+
case {all_input_names, inputs} do
906+
{[^name], %Nx.Tensor{} = inputs} ->
896907
inputs
897908

898-
%{} = inputs ->
909+
{_, %Nx.Tensor{}} ->
910+
raise ArgumentError,
911+
"ambiguous input given to the model," <>
912+
" expected inputs with names #{inspect(all_input_names)}" <>
913+
" but received a single tensor as input"
914+
915+
{_, %{} = inputs} ->
899916
inputs[name]
900917

901-
inputs when is_tuple(inputs) ->
918+
{[^name], inputs} when is_tuple(inputs) ->
902919
inputs
903920

904921
_ ->

test/axon/compiler_test.exs

+14
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ defmodule CompilerTest do
128128
assert message =~ "exception found when compiling layer Axon.Layers.add/2 named add_0"
129129
assert message =~ "cannot broadcast tensor of dimensions {1, 32} to {1, 64}"
130130
end
131+
132+
test "raises if inputs are ambiguous" do
133+
x = Axon.input("x")
134+
y = Axon.input("y")
135+
model = Axon.add(x, y)
136+
137+
{_, predict_fn} = Axon.build(model)
138+
139+
exception = assert_raise ArgumentError, fn ->
140+
predict_fn.(ModelState.empty(), Nx.tensor([1]))
141+
end
142+
143+
assert Exception.message(exception) =~ "ambiguous"
144+
end
131145
end
132146

133147
describe "optional" do

test/axon/loop_test.exs

+5-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ defmodule Axon.LoopTest do
132132
Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam)
133133

134134
assert %{model_state: %{}} =
135-
pstate = init_fn.({Nx.tensor([[2]]), Nx.tensor([[2]])}, Axon.ModelState.empty())
135+
pstate =
136+
init_fn.(
137+
{%{"input_0" => Nx.tensor([[2]]), "input_1" => Nx.tensor([[2]])}, Nx.tensor(0)},
138+
Axon.ModelState.empty()
139+
)
136140

137141
state = %State{step_state: pstate}
138142

0 commit comments

Comments
 (0)