From bb8a6f42b89b090c9b134d5710157606e9b99494 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 11 Sep 2019 04:18:45 +0000 Subject: [PATCH] julia: fix `mx.forward` kwargs checking close https://github.com/dmlc/MXNet.jl/issues/431 --- julia/src/executor.jl | 2 +- julia/test/unittest/bind.jl | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/julia/src/executor.jl b/julia/src/executor.jl index e565617976ce..37f2dde615b8 100644 --- a/julia/src/executor.jl +++ b/julia/src/executor.jl @@ -176,7 +176,7 @@ end function forward(self::Executor; is_train::Bool = false, kwargs...) for (k,v) in kwargs - @assert(k ∈ self.arg_dict, "Unknown argument $k") + @assert(k ∈ keys(self.arg_dict), "Unknown argument $k") @assert(isa(v, NDArray), "Keyword argument $k must be an NDArray") copy!(self.arg_dict[k], v) end diff --git a/julia/test/unittest/bind.jl b/julia/test/unittest/bind.jl index 0ae0ab427b99..a221733cded1 100644 --- a/julia/test/unittest/bind.jl +++ b/julia/test/unittest/bind.jl @@ -84,11 +84,26 @@ function test_arithmetic() end end +function test_forward() + # forward with data keyword argument + x = @var x + y = x .+ 42 + + A = 1:5 + B = A .+ 42 + + e = bind(y, args = Dict(:x => NDArray(24:28))) + z = forward(e, x = NDArray(A))[1] + + @test copy(z) == collect(B) +end + ################################################################################ # Run tests ################################################################################ @testset "Bind Test" begin test_arithmetic() + test_forward() end end