Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ julia> rmul!([NaN], 0.0)
```
"""
function rmul!(X::AbstractArray, s::Number)
isone(s) && return X
@simd for I in eachindex(X)
@inbounds X[I] *= s
end
Expand Down Expand Up @@ -318,6 +319,7 @@ julia> lmul!(0.0, [Inf])
```
"""
function lmul!(s::Number, X::AbstractArray)
isone(s) && return X
@simd for I in eachindex(X)
@inbounds X[I] = s*X[I]
end
Expand Down
91 changes: 91 additions & 0 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,93 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

# Entry point for matrix products of strided operands sharing a generic element type `T`.
# `tA`/`tB` are the operation characters for `A`/`B`, and `val` carries the BLAS-style
# flag that `_valtypeparam` extracts (compared below against `BlasFlag.SYRK`/`BlasFlag.HERK`).
# When both operands are the very same array, the product is routed to `generic_syrk!`
# (upper triangle only, then mirrored with `copytri!`); otherwise, or when the
# preconditions for that shortcut fail, `_generic_matmatmul!` does the work.
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
        α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
    rowsA, colsA = lapack_size(tA, A)
    rowsB, colsB = lapack_size(tB, B)

    # Nothing to multiply: an operand is empty or the product is scaled by zero.
    # Sizes are still validated before `C` is merely rescaled by `β`.
    if isempty(A) || isempty(B) || iszero(α)
        matmul_size_check(size(C), (rowsA, colsA), (rowsB, colsB))
        return _rmul_or_fill!(C, β)
    end

    if A === B
        # `uppercase` potentially strips a WrapperChar
        aat = uppercase(tA) == 'N'
        kernel = _valtypeparam(val)

        # syrk shortcut: only taken when `C` is overwritten (β == 0) or is already
        # symmetric, because only one triangle is computed and then mirrored.
        use_syrk = kernel == BlasFlag.SYRK && T <: Union{Real,Complex} &&
            (iszero(β) || issymmetric(C))
        use_syrk && return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U')

        # herk shortcut: additionally requires real scaling factors.
        use_herk = kernel == BlasFlag.HERK && isreal(α) && isreal(β) &&
            (iszero(β) || ishermitian(C))
        use_herk && return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true)
    end

    return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end

"""
generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}

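Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e.,
``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise.
If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and
``A' A α + C β`` otherwise.

Only the upper triangle of `C` (diagonal included) is accumulated into; callers are
expected to mirror it into the lower triangle afterwards, e.g. with `copytri!`.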
"""
function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
    # Rank-k update kernel for generic element types.
    #   conjugate == false: symmetric update  (A*transpose(A) or transpose(A)*A)
    #   conjugate == true : Hermitian update  (A*A' or A'*A)
    #   aat == true selects the "A times op(A)" form, false the "op(A) times A" form.
    # `C` is first passed through `_rmul_or_fill!(C, β)`; afterwards only entries
    # `C[i, j]` with `i ≤ j` are accumulated into. Returns `C`.
    # Throws `DimensionMismatch` when the side length of `C` does not match the product.
    require_one_based_indexing(C, A)
    nC = checksquare(C)
    m, n = size(A, 1), size(A, 2)
    mA = aat ? m : n
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
    end

    _rmul_or_fill!(C, β)
    # An empty `A` contributes nothing to the product. Returning here is required for
    # safety: the `!aat` branches below seed their accumulators from row 1 of `A`
    # under `@inbounds`, which would read out of bounds when `A` has zero rows.
    isempty(A) && return C

    @inbounds if !conjugate
        if aat
            # C[i, j] += Σ_k A[i, k] * A[j, k] * α, column-oriented over `A`
            for k ∈ 1:n, j ∈ 1:m
                αA_jk = A[j, k] * α
                for i ∈ 1:j
                    C[i, j] += A[i, k] * αA_jk
                end
            end
        else
            # C[i, j] += (Σ_k A[k, i] * A[k, j]) * α; the sum is seeded with the k = 1
            # term so no `zero` of the product type has to be constructed.
            for j ∈ 1:n, i ∈ 1:j
                temp = A[1, i] * A[1, j]
                for k ∈ 2:m
                    temp += A[k, i] * A[k, j]
                end
                C[i, j] += temp * α
            end
        end
    else
        if aat
            for k ∈ 1:n, j ∈ 1:m
                αA_jk_bar = conj(A[j, k]) * α
                for i ∈ 1:j-1
                    C[i, j] += A[i, k] * αA_jk_bar
                end
                # diagonal handled via `abs2` so it carries no spurious imaginary part
                C[j, j] += abs2(A[j, k]) * α
            end
        else
            for j ∈ 1:n
                for i ∈ 1:j-1
                    temp = conj(A[1, i]) * A[1, j]
                    for k ∈ 2:m
                        temp += conj(A[k, i]) * A[k, j]
                    end
                    C[i, j] += temp * α
                end
                # diagonal entry: sum of squared magnitudes of column j
                temp = abs2(A[1, j])
                for k ∈ 2:m
                    temp += abs2(A[k, j])
                end
                C[j, j] += temp * α
            end
        end
    end
    return C
end

# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
Expand Down Expand Up @@ -715,6 +802,8 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
else
return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U')
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
Expand Down Expand Up @@ -750,6 +839,8 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
else
return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true)
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
Expand Down
36 changes: 36 additions & 0 deletions test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,42 @@ end
@test_throws DimensionMismatch axpy!(α, x, Vector(1:3), y, Vector(1:5))
end

@testset "generic syrk & herk" begin
    for T ∈ (BigFloat, Complex{BigFloat}, Quaternion{Float64})
        α = randn(T)
        a = randn(T, 3, 4)
        # 3×3 buffers receive a*op(a), 4×4 buffers receive op(a)*a; the `_ref`
        # buffers are filled by calling `_generic_matmatmul!` directly.
        c33 = similar(a, 3, 3)
        c33_ref = similar(a, 3, 3)
        c44 = similar(a, 4, 4)
        c44_ref = similar(a, 4, 4)
        α_real = real(α)
        a_adj = a'
        a_tr = transpose(a)

        # adjoint products with a real scaling factor must yield Hermitian results
        mul!(c33, a, a_adj, α_real, false)
        LinearAlgebra._generic_matmatmul!(c33_ref, a, a_adj, α_real, false)
        @test ishermitian(c33)
        @test c33 ≈ c33_ref
        mul!(c44, a_adj, a, α_real, false)
        LinearAlgebra._generic_matmatmul!(c44_ref, a_adj, a, α_real, false)
        @test ishermitian(c44)
        @test c44 ≈ c44_ref

        # transpose products with a scaling factor of the full element type
        mul!(c33, a, a_tr, α, false)
        LinearAlgebra._generic_matmatmul!(c33_ref, a, a_tr, α, false)
        @test c33 ≈ c33_ref
        mul!(c44, a_tr, a, α, false)
        LinearAlgebra._generic_matmatmul!(c44_ref, a_tr, a, α, false)
        @test c44 ≈ c44_ref
        # symmetry is only asserted for real and complex element types
        if T <: Union{Real, Complex}
            for c ∈ (c33, c44)
                @test issymmetric(c)
            end
        end

        # make sure generic herk is not called for non-real α
        mul!(c33, a, a_adj, α, false)
        LinearAlgebra._generic_matmatmul!(c33_ref, a, a_adj, α, false)
        @test c33 ≈ c33_ref
        mul!(c44, a_adj, a, α, false)
        LinearAlgebra._generic_matmatmul!(c44_ref, a_adj, a, α, false)
        @test c44 ≈ c44_ref
    end
end

# a non-square (5×3) matrix is reported as neither symmetric nor Hermitian
@test !issymmetric(fill(1,5,3))
@test !ishermitian(fill(1,5,3))
# the cross product of a 3-vector with itself is the zero vector
@test (x = fill(1,3); cross(x,x) == zeros(3))
Expand Down