Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize and fix map/broadcast over Adjoint/Transpose vectors #25219

Merged
merged 1 commit into from
Dec 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 8 additions & 24 deletions base/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,30 +136,14 @@ hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
### higher order functions
# preserve Adjoint/Transpose wrapper around vectors
# to retain the associated semantics post-map/broadcast

# vectorfy takes an Adoint/Transpose-wrapped vector and builds
# an unwrapped vector with the entrywise-same contents
vectorfy(x::Number) = x
vectorfy(adjvec::AdjointAbsVec) = map(Adjoint, adjvec.parent)
vectorfy(transvec::TransposeAbsVec) = map(Transpose, transvec.parent)
vectorfyall(transformedvecs...) = (map(vectorfy, transformedvecs)...,)

# map over collections of Adjoint/Transpose-wrapped vectors
# note that the caller's operation `f` should be applied to the entries of the wrapped
# vectors, rather than the entires of the wrapped vector's parents. so first we use vectorfy
# to build unwrapped vectors with entrywise-same contents as the wrapped input vectors.
# then we map the caller's operation over that set of unwrapped vectors. but now re-wrapping
# the resulting vector would inappropriately transform the result vector's entries. so
# instead of simply mapping the caller's operation over the set of unwrapped vectors,
# we map Adjoint/Transpose composed with the caller's operationt over the set of unwrapped
# vectors. then re-wrapping the result vector yields a wrapped vector with the correct entries.
map(f, avs::AdjointAbsVec...) = Adjoint(map(Adjoint∘f, vectorfyall(avs...)...))
map(f, tvs::TransposeAbsVec...) = Transpose(map(Transpose∘f, vectorfyall(tvs...)...))

# broadcast over collections of Adjoint/Transpose-wrapped vectors and numbers
# similar explanation for these definitions as for map above
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast(Adjoint∘f, vectorfyall(avs...)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast(Transpose∘f, vectorfyall(tvs...) ...))
#
# note that the caller's operation f operates in the domain of the wrapped vectors' entries.
# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries.
map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...))
map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...))
quasiparent(x) = parent(x); quasiparent(x::Number) = x # to handle numbers in the defs below
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparent.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparent.(tvs)...))


### linear algebra
Expand Down
8 changes: 8 additions & 0 deletions test/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ end
# trinary broadcast over wrapped vectors with concrete scalar eltype and numbers
@test broadcast(+, Adjoint(vec), 1, Adjoint(vec))::Adjoint{Complex{Int},Vector{Complex{Int}}} == avec + avec .+ 1
@test broadcast(+, Transpose(vec), 1, Transpose(vec))::Transpose{Complex{Int},Vector{Complex{Int}}} == tvec + tvec .+ 1
# ascertain inference friendliness, ref. https://github.com/JuliaLang/julia/pull/25083#issuecomment-353031641
sparsevec = SparseVector([1.0, 2.0, 3.0])
@test map(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test map(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
@test broadcast(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test broadcast(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
@test broadcast(+, Adjoint(sparsevec), 1.0, Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test broadcast(+, Transpose(sparsevec), 1.0, Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
end

@testset "Adjoint/Transpose-wrapped vector multiplication" begin
Expand Down