@mul gives Array{TrackedReal} instead of TrackedArray #8
For future reference, this is a workaround
This use of
It would not be very hard to add a gradient definition for this function. It would be better still to call someone else's function, more in the spirit of this package being just the front end. For CuArrays there are special kernels for doing such things, but making all of this work nicely in Julia is a work in progress.
Relevant links:
That workaround is precisely this, which should be correct but won't be fast:
I discovered that OMEinsum now supports batch matmul, as of under-Peter/OMEinsum.jl#74. So at least with Zygote your example can now work:
using Zygote, Random, OMEinsum#master
Random.seed!(42);
Zd_ = randn(4,3,2); M_ = randn(3,2);
f(Zd, M) = (@ein E[a,d] := Zd[a,b,d]*M[b,d]; sum(exp.(E)))
Zygote.gradient(f, Zd_, M_)
using TensorCast#master # @mul doesn't handle this, but the @reduce fallback does:
g(Zd, M) = (@reduce E[a,d] := sum(b) Zd[a,b,d]*M[b,d]; sum(exp.(E)))
Zygote.gradient(g, Zd_, M_) # agrees!
Surprisingly you have to get to quite large arrays for its gradient to be faster than the
Anyway I'm going to close this essentially as won't-fix, the relevant bit of
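For anyone reading later without those packages installed, here is a minimal base-Julia sketch of the contraction that both the `@ein` and `@reduce` lines above express, `E[a,d] = sum over b of Zd[a,b,d] * M[b,d]`, together with hand-derived gradients of `sum(exp.(E))`. The helper names (`contract`, `loss`, `E_slices`, `dZd`, `dM`) and the deterministic test data are assumptions made for illustration; this is not how TensorCast, OMEinsum, or Zygote compute it internally, and it makes no claim about speed.

```julia
# Illustrative sketch only: base Julia, deterministic data, hypothetical names.
Zd = reshape(collect(1.0:24.0) ./ 24, 4, 3, 2)   # indices [a,b,d]
M  = reshape(collect(1.0:6.0) ./ 6, 3, 2)        # indices [b,d]

# E[a,d] = sum over b of Zd[a,b,d] * M[b,d], via broadcasting then summing over b.
contract(Zd, M) = dropdims(sum(Zd .* reshape(M, 1, size(M)...); dims=2); dims=2)
loss(Zd, M) = sum(exp.(contract(Zd, M)))

E = contract(Zd, M)

# Same contraction slice by slice: one matrix-vector product per batch index d.
E_slices = reduce(hcat, [Zd[:, :, d] * M[:, d] for d in axes(Zd, 3)])
@assert E ≈ E_slices

# Hand-derived gradients of loss via the chain rule:
#   dloss/dZd[a,b,d] = exp(E[a,d]) * M[b,d]
#   dloss/dM[b,d]    = sum over a of exp(E[a,d]) * Zd[a,b,d]
W   = reshape(exp.(E), size(E, 1), 1, size(E, 2))   # 4×1×2
dZd = W .* reshape(M, 1, size(M)...)                # 4×3×2
dM  = dropdims(sum(W .* Zd; dims=1); dims=1)        # 3×2

# Central finite-difference spot check on one entry of M.
h = 1e-6
Mp = copy(M); Mp[2, 1] += h
Mm = copy(M); Mm[2, 1] -= h
fd = (loss(Zd, Mp) - loss(Zd, Mm)) / (2h)
@assert isapprox(fd, dM[2, 1]; rtol=1e-6)
```

If Zygote is available, `Zygote.gradient(loss, Zd, M)` should give results comparable to `dZd` and `dM`. Everything here stays in ordinary dense arrays of floats, which is the property the original report was asking for.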
Similar to mcabbott/SliceMap.jl#3
The code below produces an array of tracked reals instead of a tracked array. If the arrays are CuArrays, then the code fails since there is no constructor for
CuArray(::Array{TrackedReal})