Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
return dest
end

# Recursively replace wrapped matrices by their parents to improve broadcasting performance
# We may do this because the indexing within `copyto!` is restricted to the stored indices
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
args = map(x -> preprocess_broadcasted(T, x), bc.args)
Broadcast.Broadcasted(bc.f, args, bc.axes)
end
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)

function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc)
for j in axs[2]
for i in j:axs[1][end]
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
end
end
return dest
Expand All @@ -285,9 +296,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc)
for j in axs[2]
for i in 1:j
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
end
end
return dest
Expand Down
6 changes: 2 additions & 4 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,8 @@ end

uppertridata(A) = A
lowertridata(A) = A
# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type
# doesn't share its indexing with the parent
uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A)
lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A)
uppertridata(A::UpperTriangular) = parent(A)
lowertridata(A::LowerTriangular) = parent(A)

@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
Expand Down