This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Make gemm_strided_batched! work with PermutedDimsArrays #664

Closed
wants to merge 7 commits

Conversation

mcabbott

CUBLAS's gemm_strided_batched! accepts strides other than the contiguous ones of an ordinary CuArray. This is potentially useful as it saves performing permutedims before operations, and the obvious way to expose this is to let it accept PermutedDimsArrays. This PR is an attempt to make that happen, while trying to avoid problems with wrappers & dispatch.
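For instance, a PermutedDimsArray is just a lazy re-labelling of the parent's strides, so (a small illustration, not from the PR itself) the strides it reports are exactly what the strided-batched routine could be handed:

using CuArrays
xp = PermutedDimsArray(cu(rand(Float32, 4, 5, 6)), (1, 3, 2))  # lazy view: no copy, no kernel launch
strides(parent(xp))  # (1, 4, 20) -- the contiguous CuArray underneath
strides(xp)          # (1, 20, 4) -- what gemm_strided_batched! would need to consume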

  1. I presume that by the time anyone is calling CuArrays.CUBLAS.gemm_strided_batched!, they are well aware that the operation is to be done by CUBLAS, and are not expecting dispatch to redirect to the CPU if given an Array. So I changed this to accept any AbstractArray.

  2. It needs a pointer CuPtr{T}, which wasn't defined for PermutedDimsArray, so I've added a generic definition which unwraps the wrapper recursively. (Not sure where such a definition should actually live; see the sketch just after this list.) Then it can run.

  3. Ideally NNlib.batched_mul will be a friendly function which dispatches every kind of array to the best routine. That is a bit of a mess right now: it needs Improvements to batched_mul, including PermutedDimsArray FluxML/NNlib.jl#187 (which similarly tries to allow for PermutedDimsArray on the CPU, but maybe not all the same ones), so don't look too hard just yet. It also needs similar() to unwrap and see the CuArray (similar(PermutedDimsArray(::CuArray)) isa Array #658), fixed by Use parent for similar(::PermutedDimsArray) JuliaLang/julia#35304 + Compat.jl I suppose, but for now copied in here to try things out.

  4. This still isn't quite the most general thing, as there are matrix * 3-tensor contractions which should perhaps be done by this routine, e.g. the table here with something like strides(A)==(1,N,1).
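Here is roughly the kind of pointer definition point 2 refers to (a sketch only; the exact method, and which package should own it, may differ from what the PR adds, and it assumes CuPtr is in scope from the CuArrays/CUDAdrv of that era):

Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} =
    Base.unsafe_convert(CuPtr{T}, parent(A))  # a permuted view shares its parent's memory, offset zero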

Here's a test of steps 1 & 2:

# CUBLAS-only batched multiply for steps 1 & 2: always 'N','N', no transposes inserted
function batched_cu(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where {T}
    axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
    C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3)))
    CuArrays.CUBLAS.gemm_strided_batched!('N', 'N', one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
    C
end

x = randn(3,3,3); y = randn(3,3,3); cx = cu(x); cy = cu(y);

# Try f with every lazy permutation of x, comparing against the eager permutedims reference
function testall(f::Function, x, y)
    for a in 1:3, b in 1:3, c in 1:3
        perm = (a,b,c)
        isperm(perm) || continue
        xp = PermutedDimsArray(x, perm)
        resx = batched_mul(permutedims(x,perm), y)
        try
            batx = f(xp, y)
            if resx ≈ batx
                @info "$f + perm = $perm is ok."
            else
                @warn "$f + perm = $perm gives wrong answer!" strides(xp)
            end
        catch err
            @error "$f + perm = $perm gives an error:" err strides(xp)
        end
    end
end

which gives

julia> testall(batched_cu, cx, cy)
[ Info: batched_cu + perm = (1, 2, 3) is ok.
[ Info: batched_cu + perm = (1, 3, 2) is ok.
 ** On entry to SGEMM  parameter number 8 had an illegal value
┌ Error: batched_cu + perm = (2, 1, 3) gives an error:
│   err = CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
│   strides(xp) = (3, 1, 9)
└ @ Main REPL[5]:15
┌ Warning: batched_cu + perm = (2, 3, 1) gives wrong answer!
│   strides(xp) = (3, 9, 1)
└ @ Main REPL[5]:12
 ** On entry to SGEMM  parameter number 8 had an illegal value
┌ Error: batched_cu + perm = (3, 1, 2) gives an error:
│   err = CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
│   strides(xp) = (9, 1, 3)
└ @ Main REPL[5]:15
┌ Warning: batched_cu + perm = (3, 2, 1) gives wrong answer!
│   strides(xp) = (9, 3, 1)
└ @ Main REPL[5]:12

The failures for perm = (2, 1, 3) etc. are fine: we should adjust the permutation & use the 'T' variant, and my messy NNlib code tries to do this (and needs to do it on the CPU too).
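For illustration, here is a sketch of that adjustment (hypothetical names, not this PR's code): a PermutedDimsArray which only swaps the first two dimensions can hand CUBLAS its contiguous parent together with the 'T' flag.

unwrap_tr(A::PermutedDimsArray{T,3,(2,1,3)}) where {T} = (parent(A), 'T')  # batched transpose
unwrap_tr(A::AbstractArray{T,3}) where {T} = (A, 'N')                      # anything else, as-is

function batched_cu2(A, B)
    A2, tA = unwrap_tr(A)
    B2, tB = unwrap_tr(B)
    T = eltype(A2)
    C = similar(A2, T, (size(A, 1), size(B, 2), size(A, 3)))
    CuArrays.CUBLAS.gemm_strided_batched!(tA, tB, one(T), A2, B2, zero(T), C)
    C
end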

The wrong answer for perm = (3, 2, 1) is more curious; note that it's not completely wrong:

perm = (3,2,1) # or perm = (2,3,1)
resx = batched_cu(permutedims(cx, perm), cy)
batx = batched_cu(PermutedDimsArray(cx, perm), cy)
resx[1,:,:]  batx[1,:,:] # true

resy = batched_cu(cx, permutedims(cy, perm))
baty = batched_cu(cx, PermutedDimsArray(cy, perm)) # here nothing is right
intersect(round.(collect(vec(resy)), digits=3), round.(collect(vec(baty)), digits=3)) # usually empty

Perhaps this is just some mistake with strides, or perhaps this case is in fact not supported; I'm not sure.
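One possible explanation: for these two permutations the first dimension of the wrapper is not contiguous, and strided-batched gemm has no parameter for that, since lda and strideA only set the column-to-column and slice-to-slice strides:

stride(PermutedDimsArray(cx, (2,3,1)), 1)  # 3 -- elements within a column are not adjacent
stride(PermutedDimsArray(cx, (3,2,1)), 1)  # 9 -- likewise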

cc @Roger-luo who has also messed with these things.

@mcabbott mcabbott changed the title Stridedbatched Make gemm_strided_batched! work with PermutedDimsArrays Mar 31, 2020
@mcabbott
Author

I also have a crude attempt at benchmarking things, on a very slow card. I'd be curious whether the gain is similar on faster cards (6x at size 128 here, 4x at size 256).

using CuArrays, NNlib, BenchmarkTools
xx, yy = randn(Float32, 128,128,128), randn(Float32, 128,128,128); cxx = cu(xx); cyy = cu(yy);
CuArrays.allowscalar(false)

# simple batched_cu, not inserting transposes, and including some wrong answers!
for a in 1:3, b in 1:3, c in 1:3
    perm_ = (a,b,c)
    isperm(perm_) || continue
    eager = @belapsed CuArrays.@sync batched_cu(permutedims($cxx, $perm_), permutedims($cyy, $perm_))
    try
        lazy = @belapsed CuArrays.@sync batched_cu(PermutedDimsArray($cxx,$perm_), PermutedDimsArray($cyy,$perm_))
        @info "perm = $perm_ times:" eager lazy eager/lazy
    catch err
        @error "perm = $perm_ gives an error" err
    end
end
# ┌ Info: perm = (1, 2, 3) times:
# │   eager = 0.011564265
# │   lazy = 0.001926885
# └   eager / lazy = 6.001533563238077
# ┌ Info: perm = (1, 3, 2) times:
# │   eager = 0.011582102
# │   lazy = 0.001946393
# └   eager / lazy = 5.9505464723722294
#  ** On entry to SGEMM  parameter number 8 had an illegal value
# ┌ Error: perm = (2, 1, 3) gives an error
# │   err = CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
# └ @ Main REPL[97]:9
# ...

# sanity checks?
@btime CuArrays.@sync batched_cu($cxx, $cyy);     #  1.948 ms (17 allocations: 512 bytes)
@btime CuArrays.@sync batched_mul($cxx, $cyy);    #  1.922 ms (17 allocations: 512 bytes)

@btime CuArrays.@sync permutedims($cxx, (1,2,3)); #  4.997 ms (89 allocations: 2.86 KiB)
@btime CuArrays.@sync permutedims($cxx, (3,2,1)); #  8.670 ms (89 allocations: 2.86 KiB)

@maleadt
Member

maleadt commented Apr 2, 2020

cc @haampie, IIUC you were looking at reworking BLAS wrappers for use with Base's strided abstractions too.

@mcabbott
Author

mcabbott commented Apr 3, 2020

Two approaches to how point 3 should work are FluxML/NNlib.jl#187 (using ArrayLayouts) and FluxML/NNlib.jl#191 (directly checking strides). What's here was awkwardly in between, but should now work with FluxML/NNlib.jl#191. That also exposes a 5-arg batched_mul!(C, A, B, α, β).

Point 4 is now also addressed. Allowing size(B,3)==1 means that you can perform @einsum C[i,k,b] = A[i,j,b] * B[j,k,1] without permutedims, which appears to be a few times quicker.
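For example, a usage sketch (assuming this PR's relaxed size check), re-using a single slice of B against every slice of A:

A = cu(rand(Float32, 4, 5, 8))
B = cu(rand(Float32, 5, 6, 1))   # one "batch" slice only
C = cu(zeros(Float32, 4, 6, 8))
CuArrays.CUBLAS.gemm_strided_batched!('N', 'N', 1f0, A, B, 0f0, C)
# now C[i,k,b] == sum_j A[i,j,b] * B[j,k,1], with no permutedims needed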

@@ -937,15 +937,16 @@ for (fname, elty) in
function gemm_strided_batched!(transA::Char,
transB::Char,
alpha::($elty),
A::CuArray{$elty, 3},
Contributor

I think when I tried to wrap this I intended it to be a low-level API. Since both CUDAnative and CuArrays have changed a lot, maybe we need a barer wrapper (taking a pointer type, CuPtr) that directly wraps the CUBLAS API? Then it would be more elegant to have a higher-level wrapper for the different Julia array types.

Author

Can NNlib.batched_mul be this higher-level wrapper? FluxML/NNlib.jl#191 makes it more flexible, and able to dispatch according to the underlying data.

And what can't you do with this wrapper (which works on any AbstractArray for which this pointer exists) that you could do with a different one?

@mcabbott
Author

Now copied to JuliaGPU/CUDA.jl#539

@mcabbott mcabbott closed this Nov 11, 2020
mcabbott pushed a commit to mcabbott/CUDA.jl that referenced this pull request Nov 12, 2020