From bacf5ba8deeed82672df8e5e8ffb6627eb1c5080 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 24 Aug 2016 15:40:48 +0200 Subject: [PATCH] Make matrix multiplication work for more types Currently it is assumed that the type of a sum of x::T and y::T is T but this may not be the case --- base/linalg/matmul.jl | 34 ++++++++++++++++++---------------- test/linalg/matmul.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/base/linalg/matmul.jl b/base/linalg/matmul.jl index aa713321a42bb5..2f44c5c2c7585f 100644 --- a/base/linalg/matmul.jl +++ b/base/linalg/matmul.jl @@ -5,6 +5,8 @@ arithtype(T) = T arithtype(::Type{Bool}) = Int +matprod(x, y) = x*y + x*y + # multiply by diagonal matrix as vector function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector) m, n = size(A) @@ -76,11 +78,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu( # Matrix-vector multiplication function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x)) end function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) A_mul_B!(similar(x,TS,size(A,1)),A,x) end (*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B @@ -99,22 +101,22 @@ end A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x) function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x)) end function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) At_mul_B!(similar(x,TS,size(A,2)), A, x) end At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x) At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x) function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x)) end function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) Ac_mul_B!(similar(x,TS,size(A,2)), A, x) end @@ -125,7 +127,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m # Matrix-matrix multiplication function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B) end A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B) @@ -142,14 +144,14 @@ end A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B) function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B) end At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B) At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B) function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B) end A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B) @@ -166,7 +168,7 @@ end A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B) function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B) end At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B) @@ -175,7 +177,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B) Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B) function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B) end Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B) @@ -184,14 +186,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B) A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B) function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) - TS = promote_op(*, arithtype(T), arithtype(S)) + TS = promote_op(matprod, arithtype(T), arithtype(S)) A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B) end A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B) A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B) Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) = - Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B) + Ac_mul_Bc!(similar(B, promote_op(matprod, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B) Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B) Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B) Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B) @@ -424,7 +426,7 @@ end function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S}) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) - C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB) + C = similar(B, promote_op(matprod, arithtype(T), arithtype(S)), mA, nB) generic_matmatmul!(C, tA, tB, A, B) end @@ -618,7 +620,7 @@ end # multiply 2x2 matrices function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) - matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B) + matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B) end function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) @@ -647,7 +649,7 @@ end # Multiply 3x3 matrices function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) - matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B) + matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B) end function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) diff --git a/test/linalg/matmul.jl b/test/linalg/matmul.jl index 4755f1d9ddd647..16391a6bc47e09 100644 --- a/test/linalg/matmul.jl +++ b/test/linalg/matmul.jl @@ -389,3 +389,33 @@ let @test_throws DimensionMismatch A_mul_B!(full43, full43, tri44) end end + +# #18218 +module TestPR18218 + using Base.Test + import Base.*, Base.+, Base.zero + immutable TypeA + x::Int + end + Base.convert(::Type{TypeA}, x::Int) = TypeA(x) + immutable TypeB + x::Int + end + Base.convert(::Type{TypeB}, x::Int) = TypeB(x) + immutable TypeC + x::Int + end + immutable TypeD + x::Int + end + Base.convert(::Type{TypeD}, x::Int) = TypeD(x) + zero(d::TypeD) = TypeD(0) + zero(::Type{TypeD}) = TypeD(0) + (*)(a::Union{TypeA,TypeB}, b::Union{TypeA,TypeB}) = TypeC(a.x*b.x) + (+)(a::Union{TypeC,TypeD}, b::Union{TypeC,TypeD}) = TypeD(a.x+b.x) + A = TypeA[1 2; 3 4] + b = TypeB[1, 2] + d = A * b + @test typeof(d) == Vector{TypeD} + @test d == TypeD[5, 11] +end