Skip to content

Commit

Permalink
Make matrix multiplication work for more types
Browse files Browse the repository at this point in the history
Currently it is assumed that the type of a sum of x::T and y::T
is T but this may not be the case
  • Loading branch information
blegat committed Aug 25, 2016
1 parent 52346ec commit bacf5ba
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
34 changes: 18 additions & 16 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
arithtype(T) = T
arithtype(::Type{Bool}) = Int

matprod(x, y) = x*y + x*y

# multiply by diagonal matrix as vector
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
m, n = size(A)
Expand Down Expand Up @@ -76,11 +78,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(

# Matrix-vector multiplication
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
Expand All @@ -99,22 +101,22 @@ end
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Expand All @@ -125,7 +127,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
# Matrix-matrix multiplication

function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
Expand All @@ -142,14 +144,14 @@ end
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
Expand All @@ -166,7 +168,7 @@ end
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
Expand All @@ -175,7 +177,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Expand All @@ -184,14 +186,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!(similar(B, promote_op(matprod, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
Expand Down Expand Up @@ -424,7 +426,7 @@ end
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
C = similar(B, promote_op(matprod, arithtype(T), arithtype(S)), mA, nB)
generic_matmatmul!(C, tA, tB, A, B)
end

Expand Down Expand Up @@ -618,7 +620,7 @@ end

# multiply 2x2 matrices
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B)
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down Expand Up @@ -647,7 +649,7 @@ end

# Multiply 3x3 matrices
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B)
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down
30 changes: 30 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,33 @@ let
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end

# #18218
module TestPR18218
using Base.Test
import Base.*, Base.+, Base.zero
immutable TypeA
x::Int
end
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
immutable TypeB
x::Int
end
Base.convert(::Type{TypeB}, x::Int) = TypeB(x)
immutable TypeC
x::Int
end
immutable TypeD
x::Int
end
Base.convert(::Type{TypeD}, x::Int) = TypeD(x)
zero(d::TypeD) = TypeD(0)
zero(::Type{TypeD}) = TypeD(0)
(*)(a::Union{TypeA,TypeB}, b::Union{TypeA,TypeB}) = TypeC(a.x*b.x)
(+)(a::Union{TypeC,TypeD}, b::Union{TypeC,TypeD}) = TypeD(a.x+b.x)
A = TypeA[1 2; 3 4]
b = TypeB[1, 2]
d = A * b
@test typeof(d) == Vector{TypeD}
@test d == TypeD[5, 11]
end

0 comments on commit bacf5ba

Please sign in to comment.