
Commit 27d69ec

add Metal extension for batched_mul

1 parent 0213868

File tree

6 files changed: +179 −0 lines


.buildkite/pipeline.yml (+21)

@@ -55,6 +55,27 @@ steps:
         NNLIB_TEST_CPU: "false"
         JULIA_NUM_THREADS: 4

+  - label: ":julia: Julia 1 + Metal GPU"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1"
+      - JuliaCI/julia-test#v1:
+          test_args: "--quickfail"
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+            - ext
+    agents:
+      queue: "juliaecosystem"
+      os: "macos"
+      arch: "aarch64"
+    timeout_in_minutes: 180
+    env:
+      NNLIB_TEST_METAL: "true"
+      NNLIB_TEST_CPU: "false"
+      JULIA_NUM_THREADS: 4
+
   - label: "Benchmarks"
     plugins:
       - JuliaCI/julia#v1:
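The new CI step is driven entirely by the two NNLIB_TEST_* environment variables, so the same selection can be reproduced locally. A minimal sketch using only Base and Pkg (the variable names come from this commit; running against a dev'ed NNlib is an assumption):

using Pkg

# Mirror the Buildkite step: enable the Metal suite, skip the CPU tests.
withenv("NNLIB_TEST_METAL" => "true", "NNLIB_TEST_CPU" => "false") do
    Pkg.test("NNlib")
end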

Project.toml (+3)

@@ -19,6 +19,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

 [extensions]
 NNlibAMDGPUExt = "AMDGPU"
@@ -27,6 +28,7 @@ NNlibCUDAExt = "CUDA"
 NNlibEnzymeCoreExt = "EnzymeCore"
 NNlibFFTWExt = "FFTW"
 NNlibForwardDiffExt = "ForwardDiff"
+NNlibMetalExt = "Metal"

 [compat]
 AMDGPU = "0.9.4, 1"
@@ -40,6 +42,7 @@ ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1"
 KernelAbstractions = "0.9.2"
 LinearAlgebra = "<0.0.1, 1"
+Metal = "1.4.2"
 Random = "<0.0.1, 1"
 Statistics = "1"
 cuDNN = "1"
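Registering Metal as a weak dependency with an [extensions] entry means NNlibMetalExt is compiled and loaded only when a session has both NNlib and Metal (Julia ≥ 1.9); CPU-only users pay nothing. A quick way to observe this, as a sketch:

using NNlib
Base.get_extension(NNlib, :NNlibMetalExt) === nothing   # true: Metal not loaded yet

using Metal                                             # triggers loading of NNlibMetalExt
Base.get_extension(NNlib, :NNlibMetalExt) isa Module    # true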

ext/NNlibMetalExt/NNlibMetalExt.jl (+49)

@@ -0,0 +1,49 @@
+module NNlibMetalExt
+
+using Metal, NNlib
+using Random: AbstractRNG  # needed for the fallback method below; not exported by Metal or NNlib
+
+# Random
+NNlib._rng_from_array(::MtlArray) = Metal.MPS.default_rng()
+
+NNlib._rng_compat_array(rng::Metal.MPS.RNG, A::MtlArray) = nothing
+NNlib._rng_compat_array(rng::AbstractRNG, A::MtlArray) = throw(ArgumentError(
+    "cannot use rng::$(typeof(rng)) with array::MtlArray, only Metal's own RNG type works"))
+
+# Batched matrix multiplication
+function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
+    eltype(C) <: Complex && @warn "don't trust this on complex arrays!" transA transB
+    Metal.MPS.matmul!(C, A, B, α, β, transA != 'N', transB != 'N')  # matmul! takes (c, a, b, α, β, transpose_a, transpose_b)
+end
+
+#=
+
+help?> Metal.MPS.matmul!
+  matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
+            transpose_left=false, transpose_right=false)
+
+  An MPSMatrixMultiplication kernel that computes: c = alpha * op(a) * op(b) + beta * c
+
+  This function should not typically be used. Rather, use the normal LinearAlgebra interface with
+  any MtlArray and it should be accelerated using Metal Performance Shaders.
+
+=#
+
+using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
+using Adapt
+using Adapt: WrappedArray
+
+const MetalBatchedAdjoint{T} = BatchedAdjoint{T, <: MtlArray{T}}
+const MetalBatchedTranspose{T} = BatchedTranspose{T, <: MtlArray{T}}
+const MetalBatchedAdjOrTrans{T} = Union{MetalBatchedAdjoint{T}, MetalBatchedTranspose{T}}
+const WrappedMetalBatchedAdjOrTrans{T, N} = WrappedArray{T, N, MetalBatchedAdjOrTrans{T}, MetalBatchedAdjOrTrans{T}}
+
+# Print and convert wrapped GPU arrays via a CPU copy, so display never hits scalar indexing.
+Base.print_array(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
+Base._show_nonempty(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
+Base.show_vector(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
+
+Base.convert(::Type{T}, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
+Base.Array{T, N}(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
+Base.collect(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = collect(adapt(Array, b))
+
+end # module NNlibMetalExt
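For context, a sketch of what the new method enables: once the extension is loaded, NNlib's generic batched_mul on MtlArrays bottoms out in the _batched_gemm! method above, and hence in Metal.MPS.matmul!. All functions below are NNlib's public API; the shapes are illustrative:

using Metal, NNlib

A = MtlArray(randn(Float32, 4, 5, 16))    # 16 batches of 4×5 matrices
B = MtlArray(randn(Float32, 5, 6, 16))    # 16 batches of 5×6 matrices

C = batched_mul(A, B)                     # 4×6×16 MtlArray, computed batchwise on the GPU

# Lazy wrappers pass a transpose flag to MPS instead of materialising a permuted copy:
Ct = batched_mul(batched_transpose(A), C) # 5×6×16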

test/ext_metal/batched_mul.jl (+82)

@@ -0,0 +1,82 @@
+@testset "batched_mul" begin
+    using NNlib: batched_mul, batched_mul!, batched_vec,
+        batched_adjoint, batched_transpose
+
+    A = randn(Float32, 3,3,2);
+    B = randn(Float32, 3,3,2);
+
+    C = batched_mul(A, B)
+    @test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B))
+
+    Ct = batched_mul(batched_transpose(A), B)
+    @test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B))
+
+    Ca = batched_mul(A, batched_adjoint(B))
+    @test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B)))
+
+    # 5-arg batched_mul!
+    C .= pi
+    batched_mul!(C, A, B, 2f0, 3f0)
+    gpuCpi = MtlArray(similar(C)) .= pi
+    @test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0)
+
+    # PermutedDimsArray
+    @test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B))
+
+    D = permutedims(B, (1,3,2))
+    Cp = batched_mul(batched_adjoint(A), B)
+    @test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2)))
+
+    # Methods which reshape
+    M = randn(Float32, 3,3)
+
+    Cm = batched_mul(A, M)
+    @test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M))
+
+    Cv = batched_vec(permutedims(A,(3,1,2)), M)
+    @test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M))
+end
+
+function print_array_strs(x)
+    str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
+    return @view split(str, '\n')[2:end]
+end
+
+@testset "BatchedAdjOrTrans" begin
+    x = rand(Float32, 3, 4, 2)
+    y = MtlArray(x)
+
+    bax = batched_adjoint(x)
+    btx = batched_transpose(x)
+    bay = batched_adjoint(y)
+    bty = batched_transpose(y)
+
+    @test sprint(show, bax) == sprint(show, bay)
+    @test sprint(show, btx) == sprint(show, bty)
+
+    @test print_array_strs(bax) == print_array_strs(bay)
+    @test print_array_strs(btx) == print_array_strs(bty)
+
+    @test Array(bax) == Array(bay)
+    @test collect(bax) == collect(bay)
+    @test Array(btx) == Array(bty)
+    @test collect(btx) == collect(bty)
+
+    for shape in (:, (12, 2))
+        rbax = reshape(bax, shape)
+        rbtx = reshape(btx, shape)
+        rbay = reshape(bay, shape)
+        rbty = reshape(bty, shape)
+
+        @test sprint(show, rbax) == sprint(show, rbay)
+        @test sprint(show, rbtx) == sprint(show, rbty)
+
+        @test print_array_strs(rbax) == print_array_strs(rbay)
+        @test print_array_strs(rbtx) == print_array_strs(rbty)
+
+        @test Array(rbax) == Array(rbay)
+        @test collect(rbax) == collect(rbay)
+        @test Array(rbtx) == Array(rbty)
+        @test collect(rbtx) == collect(rbty)
+    end
+end
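The 5-arg form exercised above follows the usual GEMM contract applied batchwise: C is overwritten with α·(A ⊠ B) + β·C. A minimal CPU check of that contract (illustrative, not part of the test file; ⊠ is NNlib's exported infix alias for batched_mul):

using NNlib: batched_mul!, ⊠

A, B = randn(Float32, 3,3,2), randn(Float32, 3,3,2)
C = fill(1f0, 3,3,2)           # β multiplies this initial C
batched_mul!(C, A, B, 2f0, 3f0)
C ≈ 2f0 .* (A ⊠ B) .+ 3f0      # true: α=2 scales the product, β=3 the old C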

test/ext_metal/runtests.jl (+6)

@@ -0,0 +1,6 @@
+
+Metal.allowscalar(false)
+
+@testset "Batched multiplication" begin
+    include("batched_mul.jl")
+end
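Metal.allowscalar(false) makes the suite fail loudly if any tested code path silently falls back to element-by-element GPU indexing. Roughly, assuming a functional Metal device:

using Metal
Metal.allowscalar(false)

x = MtlArray(rand(Float32, 4))
sum(x)      # fine: runs as a single GPU reduction
# x[1]      # would throw: scalar getindex is disallowed under allowscalar(false)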

test/runtests.jl (+18)

@@ -23,6 +23,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv

 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
+# ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests
 # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests

 const rng = StableRNG(123)

@@ -174,4 +175,21 @@ end
 else
     @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
 end
+
+if get(ENV, "NNLIB_TEST_METAL", "false") == "true"
+    Pkg.add(["Metal"])
+
+    using Metal
+    if Metal.functional()
+        @testset "Metal" begin
+            # nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather")))
+
+            include("ext_metal/runtests.jl")
+        end
+    else
+        @info "Metal.jl package is not functional. Skipping Metal tests."
+    end
+else
+    @info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them."
+end
 end
