Improvements to batched_mul, including PermutedDimsArray #187
Conversation
This has changed a lot. The current form uses https://github.com/JuliaMatrices/ArrayLayouts.jl to keep track of dimensions & strides, via traits. See also #191 for an alternative approach. Either of these can be made to work with JuliaGPU/CuArrays.jl#664, which will similarly allow permutations.
This could possibly now be done with https://github.com/SciML/ArrayInterface.jl instead of https://github.com/JuliaMatrices/ArrayLayouts.jl. But I'm hesitant to add dependencies (which would probably need to be added to CUDA.jl too), and I think #191 is a simpler, lower-tech solution.
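A rough illustration (plain Base Julia, no ArrayLayouts) of why stride bookkeeping can stand in for an eager copy: a `(2,1,3)`-`PermutedDimsArray` of a strided array is itself strided, with the first two strides swapped, so a strided BLAS kernel can read it in place instead of going through `permutedims` first. The array sizes here are made up for the example.

```julia
A = randn(3, 4, 5)                    # column-major parent array
P = PermutedDimsArray(A, (2, 1, 3))   # lazy view, no data movement

strides(A)   # (1, 3, 12): column-major strides of the parent
strides(P)   # (3, 1, 12): same memory, dims 1 and 2 swapped
```

Because only the stride tuple changes, dispatching on a strided-layout trait lets `batched_gemm!` consume `P` directly.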
I think that we can allow `batched_gemm!` to be called on many (but not all) `PermutedDimsArray`s, and this is generally much faster than first calling `permutedims`. This PR is an attempt to implement this. It also extends `batched_mul!` to take `α, β` scales like `mul!`.

EARLIER: This PR also aimed to treat `PermutedDimsArray{<:Number,3,(2,1,3)}` as equivalent to `BatchedTranspose` (see Zygote.jl#552, "batched_transpose causes a 'Need an adjoint for constructor NNlib.BatchedTranspose' error") and to allow `batched_adjoint ∘ batched_transpose` to be trivial on real-valued arrays. It also changed `copy(::BatchedAdjoint)` etc., from #100 (implementation for batch-wise matrix multiplication), to return an `Array` like `copy(::Adjoint)`.
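A minimal sketch of the behaviour described above, assuming an NNlib version where this PR's `PermutedDimsArray` support has landed (`batched_mul` and `batched_transpose` are NNlib's API; the sizes are made up):

```julia
using NNlib: batched_mul, batched_transpose

A = randn(3, 4, 5)   # 5 batches of 3×4 matrices
B = randn(3, 6, 5)   # 5 batches of 3×6 matrices

# A lazy (2,1,3) permutation transposes each matrix within its batch
# slice, so it should multiply like batched_transpose(A) — without the
# permutedims copy that would otherwise be needed:
P = PermutedDimsArray(A, (2, 1, 3))

C1 = batched_mul(P, B)                      # 4×6×5
C2 = batched_mul(batched_transpose(A), B)   # same values
C1 ≈ C2
```

The point of the PR is that the first call can hit `batched_gemm!` directly on the strided view, rather than materialising the permutation.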