
implementation for batch-wise matrix multiplication #100

Merged
merged 9 commits into FluxML:master on Feb 28, 2020

Conversation

chengchingwen
Member

I ported batched_gemm! from BatchedRoutines.jl and wrote some wrappers for doing batch-wise matrix multiplication.

@chengchingwen
Member Author

There is also a version for the GPU, but it looks like that should live in CuArrays.

@MikeInnes
Member

I like the idea of this, but cc @Roger-luo: since this is his code originally, it'd probably be good to attribute it, or perhaps set him as a commit author (whatever he's happy with).

Implementation-wise, I'm unsure why this needs to call BLAS directly; it seems like you could use mul! on views or something, but perhaps he can comment.

@chengchingwen
Member Author

@MikeInnes I did talk to @Roger-luo on the #machine-learning channel on Slack. He said he is planning to implement a JAX-like batch broadcast function and prefers to put all the batched operations in a single package, which is why he didn't submit a PR for this himself, and also why I opened this one. I'll just set Roger-luo as the commit author, then.

As for BLAS, I guess there is some performance difference?

@Roger-luo
Contributor

Hi guys, just go ahead; it was released under MIT. I'll work on something more general (a JAX-like batch broadcast) and consider whether to make a PR here or put it in a separate package (depending on how large it gets).

@Roger-luo
Contributor

Roger-luo commented Mar 27, 2019

I used BLAS directly because I need to add offsets directly to the array's memory location, which assumes the array is a strided batch array whose memory is contiguous along the batch dimension.

I profiled views before; a view does not assume that each element of a batch is contiguous in memory, so there's some obvious performance overhead. That was why I moved on to batched types. Then I found this can actually be generalized for this kind of calculation (which is the batch broadcast). I think this implementation is the fastest at the moment (though it's ugly), according to a Discourse post (someone asked for batched gemm, and there were several different implementations, I believe).

Maybe just use it first, since it at least works, and replace it with a more elegant batched broadcast in the near future? @MikeInnes

PS. MKL provides some batched routines as well, which I think have better performance on Intel CPUs.
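
For reference, a minimal sketch of the mul!-on-views fallback being discussed, assuming the last dimension is the batch dimension; the name batched_mul_fallback is hypothetical and not part of this PR. It works for any element type, but it cannot exploit the contiguous, strided layout that a direct batched_gemm! call relies on.

using LinearAlgebra

# Hypothetical fallback: multiply each batch slice with mul! on views.
function batched_mul_fallback(A::AbstractArray{T,3}, B::AbstractArray{T,3}) where T
    size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch sizes must match"))
    C = similar(A, size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        # Each slice is an independent matrix multiplication.
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k))
    end
    return C
end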

@MikeInnes
Member

OK, seems fair. The other thing is that it'd be nice to make use of the Transpose types rather than keyword arguments; then you probably won't need an explicit gradient function.

@Roger-luo
Contributor

Roger-luo commented Mar 27, 2019

Yes, I agree. @chengchingwen, if you want to wrap the routine in Julia's mul! form, you should read the mul! code for Array. I also have something related: https://github.com/Roger-luo/Batched.jl/blob/master/src/matmul.jl

Just treat the rank-3 tensor's last dimension as the batch dimension and add a batched prefix to every batch-related function; then you don't need a dedicated type for this kind of array.

As for Transpose, it doesn't work for rank-3 tensors, which is why I implemented several new array types for batched arrays. @MikeInnes I guess we will need to have a BatchedTranspose and BatchedAdjoint here. @chengchingwen, could you move them over from Batched.jl?

@MikeInnes
Member

Oh, yeah, that's true. So either ignore my Transpose comment or move over BatchedTranspose.

@chengchingwen
Member Author

@Roger-luo What's the difference between BatchedTranspose and BatchedAdjoint? Should I move both of them?

@Roger-luo
Contributor

Roger-luo commented Mar 27, 2019

Yes, transpose is not the same as adjoint when the values are complex. And remember to overload LinearAlgebra.transpose and LinearAlgebra.adjoint.

(Maybe not here? This is one of the reasons I didn't make the PR: overloading them here would be type piracy. There's no type piracy in Batched, since it has BatchedArray. @MikeInnes)

Another solution is to define batched_transpose and batched_adjoint here, and remember to implement their gradients.

@chengchingwen
Member Author

@Roger-luo I made a small modification: BatchedTranspose is a subtype of AbstractArray{T, 3}, since we only have a batched_gemm! for AbstractArray{T,3}. I don't think it's a good idea to overload LinearAlgebra.transpose, so I just export batched_transpose.
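
A rough sketch of what such wrapper types might look like as subtypes of AbstractArray{T,3}; the names follow the discussion, but the details here are illustrative and may differ from the code actually added in this PR.

# Illustrative only: lazy wrappers that swap the first two dimensions of a
# rank-3 array, leaving the batch (last) dimension alone.
struct BatchedTranspose{T, S <: AbstractArray{T,3}} <: AbstractArray{T,3}
    parent::S
end

struct BatchedAdjoint{T, S <: AbstractArray{T,3}} <: AbstractArray{T,3}
    parent::S
end

batched_transpose(A::AbstractArray{T,3}) where T = BatchedTranspose{T, typeof(A)}(A)
batched_adjoint(A::AbstractArray{T,3}) where T = BatchedAdjoint{T, typeof(A)}(A)

Base.size(A::Union{BatchedTranspose, BatchedAdjoint}) =
    (size(A.parent, 2), size(A.parent, 1), size(A.parent, 3))

Base.getindex(A::BatchedTranspose, i::Int, j::Int, k::Int) = A.parent[j, i, k]
Base.getindex(A::BatchedAdjoint, i::Int, j::Int, k::Int) = conj(A.parent[j, i, k])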

B = randn(5,7,3)
C = randn(7,6,3)

@test batched_mul(A, B) == bmm_test(A, B)
Contributor

add some tests for adjoint and ComplexF64?

@Roger-luo
Contributor

Roger-luo commented Apr 2, 2019

Yes, it only makes sense to overload LinearAlgebra.transpose when you have a BatchedArray type. Let's do it the simple way here first.

It looks good to me in general, but could you add some tests for batched_adjoint as well, since there are a bunch of related functions there? Then this should be good to go.

Any further comments? @MikeInnes

@chengchingwen
Member Author

@Roger-luo Should I also write the gradient function for BatchedAdjoint? I don't really know what the gradient would be when the numbers are complex.

@Roger-luo
Contributor

It would be great if you could; it is actually straightforward: apply the same operation to the incoming gradients, i.e. batched_transpose(delta) and batched_adjoint(delta).

Or you could read how this is implemented in Tracker

https://github.com/FluxML/Tracker.jl/blob/38e3a886b4df9c5b768e12f94ebb3148cb7019a2/src/lib/array.jl#L126

Adjoint simply means transpose and conjugate.
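
In other words, a minimal sketch (the ∇-prefixed names here are hypothetical, not the PR's actual definitions): the pullback of a batched transpose or adjoint simply applies the same wrapper to the incoming gradient Δ.

# Pullbacks for the batched wrappers: same operation applied to Δ.
∇batched_transpose(Δ::AbstractArray{<:Any,3}) = batched_transpose(Δ)
∇batched_adjoint(Δ::AbstractArray{<:Any,3}) = batched_adjoint(Δ)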

@chengchingwen
Member Author

@Roger-luo I can't find the definition of the trim function, but it looks like the gradient for adjoint behaves almost the same as for transpose?

@Roger-luo
Contributor

I don't think you need trim here. Yes, it's the same as the forward pass.



function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T
    (batched_mul(Δ, batched_transpose(B)), batched_mul(A, Δ))
end
Member

Why do you need these extra methods? Don't you need to do batched_transpose(A) to unwrap A?

Member Author

These extra methods are provided for defining the gradient function; maybe it should be in NNlib.jl? And yes, the implementation is not correct. I forgot that I'm using BatchedTranspose; I need to check that again.

Member

I see, but I'm suggesting that only a single gradient definition (the above) is needed, and these ones shouldn't be necessary.

Member Author

chengchingwen commented Apr 3, 2019

Sorry about that, the gradient function is completely wrong. The reason it needs several definitions is that it's not always Δ multiplied by A or B; when A is transposed, it would be B multiplied by Δ. I'll push the correct version later.

Member

This still doesn't look correct. You can replace (delta*b') with (b*delta')' but you're not currently doing the transpose of the output.

In any case, since the original definition (delta*b') would calculate the same thing, why have the transposed version at all?

Member Author

chengchingwen commented Apr 4, 2019

Oh, I see. That's because I didn't define the gradients for batched_transpose and batched_adjoint.
I'll remove the transposed version. Do I need to define the gradient for batched_transpose here, or should I just leave it to Flux (or Tracker)?
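
For reference, a sketch of the single gradient definition suggested above, for real-valued rank-3 arrays; as noted further down in the thread, the explicit ∇batched_mul was eventually dropped from the PR, so this is illustrative only.

# Single gradient rule for C = batched_mul(A, B), real element types:
#   ∂L/∂A = Δ * Bᵀ and ∂L/∂B = Aᵀ * Δ, applied batch-wise.
# Since BatchedTranspose/BatchedAdjoint are themselves AbstractArray{T,3},
# wrapped arguments reuse this same rule; no extra methods are needed.
function ∇batched_mul(Δ::AbstractArray{T,3}, A::AbstractArray{T,3}, B::AbstractArray{T,3}) where T
    (batched_mul(Δ, batched_transpose(B)),   # gradient w.r.t. A
     batched_mul(batched_transpose(A), Δ))   # gradient w.r.t. B
end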

@chengchingwen
Member Author

@MikeInnes this should be correct, for the transpose one.

@chengchingwen
Member Author

bump

@chengchingwen
Member Author

I just realized we don't need ∇batched_mul at all, so I removed it.

@chengchingwen
Member Author

@MikeInnes Are we missing something?

@mcabbott
Member

mcabbott commented Sep 25, 2019

Bump, this would be nice to have. It might be worth adding (although these could be a later PR):

  • A fallback implementation using views and mul!, which would work for e.g. dual numbers.
  • The ability to apply scaling constants, as mul!(Y,A,B,α,β) soon will (and gemm! already does); see the sketch after this list.
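
A minimal sketch of what the scaling constants could look like, assuming a Julia version with 5-argument mul! (1.3 or later); the batched_mul!(C, A, B, α, β) signature shown here is hypothetical and slice-wise, not the gemm!-backed method from this PR.

using LinearAlgebra

# C[:, :, k] = α * A[:, :, k] * B[:, :, k] + β * C[:, :, k] for each k.
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3},
                      B::AbstractArray{T,3}, α::Number=one(T), β::Number=zero(T)) where T
    for k in axes(C, 3)
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k), α, β)
    end
    return C
end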

@codecov-io

codecov-io commented Sep 26, 2019

Codecov Report

Merging #100 into master will decrease coverage by 2.07%.
The diff coverage is 65.38%.


@@            Coverage Diff             @@
##           master     #100      +/-   ##
==========================================
- Coverage   79.03%   76.95%   -2.08%     
==========================================
  Files          24       26       +2     
  Lines         763      820      +57     
==========================================
+ Hits          603      631      +28     
- Misses        160      189      +29
Impacted Files Coverage Δ
src/NNlib.jl 100% <ø> (ø) ⬆️
src/batched/batchedadjtrans.jl 38.46% <38.46%> (ø)
src/batched/batchedmul.jl 83.33% <83.33%> (ø)
src/gemm.jl 95.23% <95%> (-4.77%) ⬇️
src/softmax.jl 27.02% <0%> (-22.98%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@jumerckx
Contributor

jumerckx commented Oct 6, 2019

On CuArrays#master I get this error.

julia> batched_mul(cu(rand(10,10,3)), cu(rand(10,10,3)))
ERROR: conversion to pointer not defined for CuArray{Float32,3}
Stacktrace:
 [1] unsafe_convert(::Type{Ptr{Float32}}, ::CuArray{Float32,3}) at .\pointer.jl:67
 [2] batched_mul!(::CuArray{Float32,3}, ::CuArray{Float32,3}, ::CuArray{Float32,3}) at C:\Users\jules\.julia\dev\NNlib\src\gemm.jl:82
 [3] batched_mul(::CuArray{Float32,3}, ::CuArray{Float32,3}) at C:\Users\jules\.julia\dev\NNlib\src\batchedmul.jl:10
 [4] top-level scope at none:0

@Roger-luo
Contributor

Roger-luo commented Oct 6, 2019

@merckxiaan CUDA is not supported in this PR. If you need CUDA, check https://github.com/Roger-luo/CuBatchedRoutines.jl

And just an update for whoever cares about batched operations.

IMO, I don't think the general solution for batched operations will eventually go into NNlib as PRs; it would make this package too complicated. They should be supported by things like Hydra.jl as primitive routines. Also, the current implementation of BatchedRoutines doesn't consider cases like small matrices/vectors with a large batch, which may cause a slowdown of around 10~15x. I do have these implementations in my research code, but it's a bit messy now.

@maleadt also had some ideas on this after we discussed it during JuliaCon. I also have a gist for this kind of thing that generates generic SPMD operation code (more specific to machine learning tasks, probably, and less general than Hydra), but it is stuck on an IRTools issue for now (FluxML/IRTools.jl#21). I think I'll need to wait until summer to have some time to wrap these things up and make them stable. And I'm not sure what the status of Hydra is so far; it seems to be abandoned.

But feel free to take whatever you need from BatchedRoutines/CuBatchedRoutines for now (in principle we need this PR merged first, and then a commit to CuArrays to support the batched_gemm API).

@Roger-luo
Contributor

Roger-luo commented Oct 6, 2019

@mcabbott

A fallback implementation using views and mul!, which would work for e.g. dual numbers.

Actually, generally speaking, a fallback can be directly supported by transforming the generic matmul function into an SPMD version, but a view-based implementation would also be great. I don't think it should be this PR's task, though; it would be nice to have it as another PR.

The ability to apply constants, like mul!(Y,A,B,α,β) will do soon (and gemm! already does).

I think we just need to rewrite all the dispatches similarly to mul! in Base, but this is also something I wanted to avoid; it should be easily inferred from mul! (and actually I tried, but the performance was bad somehow...).

Another idea I had is to generate these SPMD instructions directly from ASTs and save them as scripts, instead of transforming SSA IR at runtime. It seems harder, but it should get rid of the performance issue easily and can be precompiled. Most routines do not require any runtime information.

Member

MikeInnes left a comment

This looks generally good, thanks! Perhaps all this code should just go in a batched/ folder though, rather than having the batch*.jl files.


Base.parent(A::BatchedAdjOrTrans) = A.parent

(-)(A::BatchedAdjoint) = BatchedAdjoint( -A.parent)
Member

This seems oddly specific (vs e.g. broadcast). Is this how base handles it -- or is there another reason?

Member Author

Yes, I tried to follow the style in the LinearAlgebra stdlib.

Btw, this reminds me: do we need BatchedMat and BatchedVec types?

Contributor

Roger-luo commented Oct 10, 2019

It's how Base handles it, and I don't think we need to handle too many cases for these two types. They are mainly defined so that batched routines can implement pullbacks, as an intermediate solution instead of a generic batched array.

I'd like to make it as simple as possible here: treat Array as the batched type directly, and define a more generic solution in a separate package in the future.

Otherwise you will end up doing the same thing as in Batched (and you'll find that defining these two is not enough); then why not just use its early version, which defined more types and should just work with the existing pullbacks?

@CarloLucibello
Member

@chengchingwen is this ready for merge? @MikeInnes could we get this merged? It has been open for a very long time, and the functionality here is quite useful.

@chengchingwen
Member Author

@CarloLucibello yes, I think it’s ready for merge.

@AzamatB

AzamatB commented Feb 16, 2020

Bump. I need this


include("./batchedadjtrans.jl")

function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
Member

a docstring here?
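
One possible docstring along the lines requested, based on the behaviour described earlier in the thread; the wording that was actually merged may differ.

# A suggested docstring sketch, not the merged text.
"""
    batched_mul(A, B) -> C

Batched matrix multiplication for two rank-3 arrays. The last dimension is
treated as the batch dimension, so `C[:, :, k] == A[:, :, k] * B[:, :, k]`
for every `k`. Wrap an argument in `batched_transpose` / `batched_adjoint`
to multiply by the (conjugate) transpose of each slice.
"""
function batched_mul end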

@CarloLucibello
Member

Should batched_mul and batched_mul! be exported? Also, after merging this, we should mention these functions in Flux's documentation.

@CarloLucibello
Member

Let's merge this; we can wait a bit before exporting it.

CarloLucibello merged commit 3f607e1 into FluxML:master Feb 28, 2020
@CarloLucibello
Member

thanks!

bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Mar 2, 2020
531: Gradient for batched_mul r=CarloLucibello a=mcabbott

Adds a gradient definition for FluxML/NNlib.jl#100

Co-authored-by: Michael Abbott <me@escbook>
bors bot added a commit to JuliaGPU/CuArrays.jl that referenced this pull request Mar 10, 2020
619: NNlib batched_mul! r=maleadt a=mcabbott

This hooks up FluxML/NNlib.jl#100 for CuArrays. 

I've added an extremely simple test, but I could find the right place to insert it into test/blas.jl, perhaps, if that would be better?

Tests fail locally, before this PR, but I'm not sure why.

I think that `_GemmFloat == CUBLAS.CublasFloat`, except that this isn't loaded yet. Is there a neater way?

Co-authored-by: Michael Abbott <me@pseudomac>