implementation for batch-wise matrix multiplication #100
Conversation
There is also a version for GPU, but it looks like it should be in CuArrays. |
I like the idea of this, but cc @Roger-luo. Since this is his code originally, it'd probably be good to attribute it, or perhaps set him as a commit author (whatever he's happy with). Implementation-wise, I'm unsure why this needs to call BLAS directly; it seems like you could use |
@MikeInnes I did talk to @Roger-luo on the #machine-learning channel on Slack. He said that he is planning to implement a JAX-like batch broadcast function and that he preferred to put all the batched operations in a single package, so he didn't submit a PR for this. That is also the reason I made this PR. So I just set Roger-luo as the commit author. About BLAS, I guess there is some performance difference? |
Hi guys, just go ahead; it was released under MIT. I'll work on something more general (a JAX-like batch broadcast) and consider whether to make a PR here or put it in a separate package (depending on how large it is). |
I used BLAS directly because I need to add the offset directly to the underlying pointer (I profiled it). Maybe just use it first, since it works at least, and replace it with a more elegant batched broadcast in the near future? @MikeInnes
PS. MKL provides some batched routines as well, which I think have better performance on Intel CPUs. |
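For comparison, a minimal sketch of the view-based alternative mentioned above; the helper name batched_mul_views! is hypothetical and not part of this PR, which instead passes pointer offsets straight to BLAS to avoid the per-slice view overhead.

using LinearAlgebra

# Reference implementation: multiply each slice along the third (batch)
# dimension with mul! on views. Correct, but each view adds wrapper overhead
# that the pointer-offset batched_gemm! in this PR avoids.
function batched_mul_views!(C::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
    @assert size(A, 3) == size(B, 3) == size(C, 3)
    for k in axes(A, 3)
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k))
    end
    return C
end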
OK, seems fair. The other thing is that it seems like it'd be nice to make use of the |
Yes, I agree. @chengchingwen, if you want to wrap the routine into Julia's interface: just treat the rank-3 tensor's last dimension as the batch dimension, and add the batched_ prefix. As for |
Oh, yeah, that's true. So either ignore my |
@Roger-luo What's the difference between |
Yes (maybe not here?). This is one of the reasons I didn't make the PR: there would be type piracy if they were overloaded. But there's no type piracy in Batched, since it has its own types. Another solution is to define |
@Roger-luo I made a small modification to |
B = randn(5,7,3)
C = randn(7,6,3)

@test batched_mul(A, B) == bmm_test(A, B)
Add some tests for adjoint and ComplexF64?
Yes, it only makes sense to overload that one. It looks good in general to me, but could you add some tests for adjoint and ComplexF64? Any further comments? @MikeInnes |
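A rough sketch of the kind of extra tests being requested (adjoint plus ComplexF64), assuming batched_mul and batched_adjoint as defined in this PR; the slice-wise comparison below is only illustrative.

using Test

A = randn(ComplexF64, 7, 5, 3)
B = randn(7, 6, 3) .+ 0im
# Compare the batched adjoint multiplication against plain slice-wise adjoint.
expected = cat([A[:, :, k]' * B[:, :, k] for k in 1:3]...; dims = 3)
@test batched_mul(batched_adjoint(A), B) ≈ expected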
@Roger-luo Should I also make the gradient function for batched_mul? |
It would be great if you could; it is actually quite straightforward: apply the same operation to the incoming gradients. Or you could read how this is implemented in Tracker
|
@Roger-luo I can't find the definition of the trim function. Is the gradient computed with the same batched operation as the forward pass? |
I don't think you need trim here. And yes, it's the same as the forward pass. |
src/batchedmul.jl
Outdated
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T
    (batched_mul(Δ, batched_transpose(B)), batched_mul(A, Δ))
Why do you need these extra methods? Don't you need to do batched_transpose(A) to unwrap A?
These extra methods are provided for defining the gradient function; maybe they should be in NNlib.jl? And yes, the implementation is not correct; I forgot that I'm using BatchedTranspose. I need to check that again.
I see, but I'm suggesting that only a single gradient definition (the above) is needed, and these ones shouldn't be necessary.
Sorry about that, the gradient function is completely wrong. The reason it needs several definitions is that it isn't always Δ multiplying A or B: when A is transposed, it would be B multiplying Δ. I'll push the correct version later.
This still doesn't look correct. You can replace (delta*b') with (b*delta')', but you're not currently doing the transpose of the output. In any case, since the original definition (delta*b') would calculate the same thing, why have the transposed version at all?
Oh, I see. That's because I didn't define the gradient for batched_transpose and batched_adjoint.
I'll remove the transposed version. Then, do I need to define the gradient for batched_transpose here, or do I just leave it to Flux (or Tracker)?
@MikeInnes this should be correct now, for the gradient. |
bump |
I just realized we don't need |
@MikeInnes Did we miss something? |
Bump, this would be nice to have. It might be worth adding (although these could be a later PR):
|
Codecov Report
@@            Coverage Diff             @@
##           master     #100      +/-   ##
==========================================
- Coverage   79.03%   76.95%   -2.08%
==========================================
  Files          24       26       +2
  Lines         763      820      +57
==========================================
+ Hits          603      631      +28
- Misses        160      189      +29
Continue to review full report at Codecov.
|
On CuArrays#master I get this error.
julia> batched_mul(cu(rand(10,10,3)), cu(rand(10,10,3)))
ERROR: conversion to pointer not defined for CuArray{Float32,3}
Stacktrace:
[1] unsafe_convert(::Type{Ptr{Float32}}, ::CuArray{Float32,3}) at .\pointer.jl:67
[2] batched_mul!(::CuArray{Float32,3}, ::CuArray{Float32,3}, ::CuArray{Float32,3}) at C:\Users\jules\.julia\dev\NNlib\src\gemm.jl:82
[3] batched_mul(::CuArray{Float32,3}, ::CuArray{Float32,3}) at C:\Users\jules\.julia\dev\NNlib\src\batchedmul.jl:10
[4] top-level scope at none:0 |
@merckxiaan CUDA is not supported in this PR. If you need CUDA, check https://github.com/Roger-luo/CuBatchedRoutines.jl.
And just an update to whoever cares about batched operations: IMO the general solution for batched operations won't end up going into NNlib as PRs; it would make this package too complicated, since these things should be supported by something like Hydra.jl as primitive routines, and the current implementation here is only an intermediate solution. @maleadt also had some ideas on this after we discussed it during JuliaCon. I also have a gist for this kind of thing that generates generic SPMD operation code (probably more specific to machine learning tasks; it is less general than Hydra), but it is stuck on an IRTools issue for now (FluxML/IRTools.jl#21), and I think I'll need to wait until summer to have some time to wrap these things up and make them stable. And I'm not sure what the status of Hydra is so far; it seems to be abandoned. But feel free to take whatever is in BatchedRoutines/CuBatchedRoutines for now (in principle we need this PR to be merged, then commit to CuArrays to support batched_mul for CuArray). |
Actually, generally speaking, a fallback can be directly supported by transforming the
I think we just need to rewrite all the dispatch in a similar way. Another idea I had is to generate these SPMD instructions directly from ASTs and save them as scripts, instead of transforming SSA IR during runtime. It seems harder, but it should get rid of the performance issue easily and can be precompiled. Most routines do not require any runtime information. |
This looks generally good, thanks! Perhaps all this code should just go in a batched/ folder though, rather than having the batch*.jl files.
src/batchedadjtrans.jl
Outdated
Base.parent(A::BatchedAdjOrTrans) = A.parent

(-)(A::BatchedAdjoint) = BatchedAdjoint(-A.parent)
This seems oddly specific (vs e.g. broadcast). Is this how Base handles it, or is there another reason?
Yes, I tried to follow the style of the LinearAlgebra stdlib there.
Btw, this reminds me: do we need BatchedMat and BatchedVec types?
It's how Base handles it, and I think we don't need to handle too many cases for these two types. They are mainly defined so the batched routines can implement pullbacks, rather than being a generic batched array; it's an intermediate solution.
I'd like to keep things as simple as possible here: treat Array as the batched type directly, and define a more generic solution in a separate package in the future.
Otherwise you will end up doing the same thing as Batched (and you'll find defining these two is not enough), and then why not just use its early version, which defined more types and should just work with existing pullbacks?
@chengchingwen is this ready for merge? @MikeInnes could we get this merged? It has been open for a very long time, and the functionality here is quite useful. |
@CarloLucibello yes, I think it’s ready for merge. |
Bump. I need this |
include("./batchedadjtrans.jl")

function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a docstring here?
Should batched_mul be exported? |
let's merge this, we can wait a bit before exporting it |
thanks! |
531: Gradient for batched_mul r=CarloLucibello a=mcabbott Adds a gradient definition for FluxML/NNlib.jl#100 Co-authored-by: Michael Abbott <me@escbook>
619: NNlib batched_mul! r=maleadt a=mcabbott This hooks up FluxML/NNlib.jl#100 for CuArrays. I've added an extremely simple test, but couldn't find the right place to insert this; into test/blas.jl perhaps, if that would be better? Tests fail locally, before this PR, but I'm not sure why. I think that `_GemmFloat == CUBLAS.CublasFloat`, except that this isn't loaded yet. Is there a neater way? Co-authored-by: Michael Abbott <me@pseudomac>
I ported batched_gemm! from BatchedRoutines.jl and made some wrappers for doing batch-wise matrix multiplication.
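A small usage sketch of the wrapper described above, checked against slice-wise multiplication; batched_mul may not be exported yet, so it is qualified with the module name.

using NNlib, Test

A = randn(5, 7, 3)
B = randn(7, 6, 3)
# One matrix multiplication per slice of the third (batch) dimension; C has size (5, 6, 3).
C = NNlib.batched_mul(A, B)

@test all(C[:, :, k] ≈ A[:, :, k] * B[:, :, k] for k in 1:3)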