diff --git a/Project.toml b/Project.toml index 2c6506d7..00977c2c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUArrays" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "11.3.3" +version = "11.3.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 1558dc5a..3f19aa47 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -1,6 +1,6 @@ # integration with LinearAlgebra stdlib -using LinearAlgebra: MulAddMul, wrap, diagm +using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal ## transpose and adjoint @@ -493,6 +493,16 @@ end function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T} LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) end + # need to support mixed complex/real types too + #function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::V) where {T<:BlasReal, V<:LinearAlgebra.BlasFlag.SyrkHerkGemm} + # LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) + #end + function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::Val{LinearAlgebra.BlasFlag.GEMM}) where T<:Union{Float32, Float64} + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) + end + function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{Complex{T}}, alpha::Number, beta::Number, val::Val{LinearAlgebra.BlasFlag.GEMM}) where T<:Union{Float32, Float64} + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) + end # Julia 1.12 introduced generic_mul! for scalar * array operations function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, X::AbstractGPUVecOrMat, s::Number, alpha::Number, beta::Number) if length(C) != length(X) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 31637977..a9412146 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -484,6 +484,15 @@ end @test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5))) @test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix end + @testset "$(complex(T)), $(complex(T)), $T gemm C := A * B * a + C * b" for T in filter(T-><:(T, Real) && <:(T, AbstractFloat), eltypes) + Tc = complex(T) + A, B, C = rand(Tc, 4, 4), rand(T, 4, 4), rand(Tc, 4, 4) + + @test compare(*, AT, A, B) + @test compare(mul!, AT, C, A, B) + @test compare(mul!, AT, C, A, B, Ref(T(4)), Ref(T(5))) + @test typeof(AT(rand(Tc, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix + end end @testsuite "linalg/norm" (AT, eltypes)->begin