From 6db0a4f31259cee57865d131bd591382a18cca86 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 9 Dec 2015 09:01:52 -0600 Subject: [PATCH] Type-stability fixes for matmul --- base/linalg/matmul.jl | 23 ++++++++++++++++------- test/linalg/matmul.jl | 3 +++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/base/linalg/matmul.jl b/base/linalg/matmul.jl index a26df61583f1b..909794e50b326 100644 --- a/base/linalg/matmul.jl +++ b/base/linalg/matmul.jl @@ -348,6 +348,7 @@ function copy!{R,S}(B::AbstractVecOrMat{R}, ir_dest::UnitRange{Int}, jr_dest::Un Base.copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src) tM == 'C' && conj!(B) end + B end function copy_transpose!{R,S}(B::AbstractMatrix{R}, ir_dest::UnitRange{Int}, jr_dest::UnitRange{Int}, tM::Char, M::AbstractVecOrMat{S}, ir_src::UnitRange{Int}, jr_src::UnitRange{Int}) @@ -435,15 +436,9 @@ const Abuf = Array(UInt8, tilebufsize) const Bbuf = Array(UInt8, tilebufsize) const Cbuf = Array(UInt8, tilebufsize) -function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) +function generic_matmatmul!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) - if mB != nA - throw(DimensionMismatch("matrix A has dimensions ($mA, $nB), matrix B has dimensions ($mB, $nB)")) - end - if size(C,1) != mA || size(C,2) != nB - throw(DimensionMismatch("result C has dimensions $(size(C)), needs ($mA, $nB)")) - end if mA == nA == nB == 2 return matmul2x2!(C, tA, tB, A, B) @@ -451,6 +446,20 @@ function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVe if mA == nA == nB == 3 return matmul3x3!(C, tA, tB, A, B) end + _generic_matmatmul!(C, tA, tB, A, B) +end + +generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) = _generic_matmatmul!(C, tA, tB, A, B) + +function _generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) + mA, nA = lapack_size(tA, A) + mB, nB = lapack_size(tB, B) + if mB != nA + throw(DimensionMismatch("matrix A has dimensions ($mA, $nB), matrix B has dimensions ($mB, $nB)")) + end + if size(C,1) != mA || size(C,2) != nB + throw(DimensionMismatch("result C has dimensions $(size(C)), needs ($mA, $nB)")) + end tile_size = 0 if isbits(R) && isbits(T) && isbits(S) diff --git a/test/linalg/matmul.jl b/test/linalg/matmul.jl index 79ccd57a7bf37..ea4b0f2847da8 100644 --- a/test/linalg/matmul.jl +++ b/test/linalg/matmul.jl @@ -59,6 +59,9 @@ A = rand(1:20, 5, 5) .- 10 B = rand(1:20, 5, 5) .- 10 @test At_mul_B(A, B) == A'*B @test A_mul_Bt(A, B) == A*B' +v = [1,2] +C = Array(Int, 2, 2) +@test @inferred(A_mul_Bc!(C, v, v)) == [1 2; 2 4] # Preallocated C = Array(Int, size(A, 1), size(B, 2))