
batched_transpose causes a 'Need an adjoint for constructor NNlib.BatchedTranspose' error #552

Open
nirmal-suthar opened this issue Mar 21, 2020 · 1 comment


@nirmal-suthar (Contributor)

I get this error whenever I use NNlib.batched_transpose(x) in the model. It occurs during backpropagation; forward propagation runs fine.

ERROR: LoadError: Need an adjoint for constructor NNlib.BatchedTranspose{Float64,Array{Float64,3}}. Gradient is of type Array{Float64,3}

@mcabbott (Member)

Is the MWE something like this?

julia> using Zygote, NNlib

julia> gradient((x,y) -> sum(batched_mul(PermutedDimsArray(x,(2,1,3)), y)), rand(2,2,2), rand(2,2,2))[1] |> summary
"2×2×2 PermutedDimsArray(::Array{Float64,3}, (2, 1, 3)) with eltype Float64"

julia> gradient((x,y) -> sum(batched_mul(batched_transpose(x), y)), rand(2,2,2), rand(2,2,2))[1] |> summary
ERROR: Need an adjoint for constructor NNlib.BatchedTranspose{Float64,Array{Float64,3}}. Gradient is of type Array{Float64,3}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.Jnew{NNlib.BatchedTranspose{Float64,Array{Float64,3}},Nothing,false})(::Array{Float64,3}) at /Users/me/.julia/dev/Zygote/src/lib/lib.jl:294
 [3] (::Zygote.var"#378#back#196"{Zygote.Jnew{NNlib.BatchedTranspose{Float64,Array{Float64,3}},Nothing,false}})(::Array{Float64,3}) at /Users/me/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] BatchedTranspose at /Users/me/.julia/packages/NNlib/FAI3o/src/batched/batchedadjtrans.jl:22 [inlined]
 [5] (::typeof((NNlib.BatchedTranspose{Float64,Array{Float64,3}})))(::Array{Float64,3}) at /Users/me/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [6] BatchedTranspose at /Users/me/.julia/packages/NNlib/FAI3o/src/batched/batchedadjtrans.jl:40 [inlined]
 [7] (::typeof((NNlib.BatchedTranspose)))(::Array{Float64,3}) at /Users/me/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [8] batched_transpose at /Users/me/.julia/packages/NNlib/FAI3o/src/batched/batchedadjtrans.jl:26 [inlined]
 [9] (::typeof((batched_transpose)))(::Array{Float64,3}) at /Users/me/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [10] #9 at ./REPL[6]:1 [inlined]
 [11] (::typeof((#9)))(::Float64) at /Users/me/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#38#39"{typeof((#9))})(::Float64) at /Users/me/.julia/dev/Zygote/src/compiler/interface.jl:36
 [13] gradient(::Function, ::Array{Float64,3}, ::Vararg{Array{Float64,3},N} where N) at /Users/me/.julia/dev/Zygote/src/compiler/interface.jl:45
 [14] top-level scope at REPL[6]:1

The NNlib.BatchedTranspose wrapper would need more work to become general-purpose. Perhaps we should just adjust batched_mul to accept this PermutedDimsArray directly, without falling back to the generic method?
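For comparison, a user-side workaround is to give batched_transpose its own adjoint, so that Zygote never tries to differentiate the BatchedTranspose constructor. The snippet below is a minimal sketch, not NNlib's code: it assumes Zygote's `Zygote.@adjoint` macro, and it assumes the incoming gradient `Δ` is a plain 3-d array, as the error message above reports. The pullback of a batched transpose is the batched transpose of the incoming gradient, materialised here with `permutedims`.

```julia
using Zygote, NNlib

# Hypothetical user-side rule (not part of NNlib): make batched_transpose a
# primitive for Zygote, so the BatchedTranspose constructor is never differentiated.
# The pullback transposes each slice of the incoming gradient Δ (assumed to be a
# plain 3-d array) and returns an ordinary Array via permutedims.
Zygote.@adjoint batched_transpose(A::AbstractArray{<:Any,3}) =
    batched_transpose(A), Δ -> (permutedims(Δ, (2, 1, 3)),)

x, y = rand(2, 2, 2), rand(2, 2, 2)
gx = gradient((x, y) -> sum(batched_mul(batched_transpose(x), y)), x, y)[1]

# Should agree with the PermutedDimsArray version from the MWE above.
gp = gradient((x, y) -> sum(batched_mul(PermutedDimsArray(x, (2, 1, 3)), y)), x, y)[1]
gx ≈ gp
```

Returning `permutedims(Δ, (2, 1, 3))` makes a copy; returning `batched_transpose(Δ)` would stay lazy, but would pass the wrapper on to whatever consumes the gradient.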
