From e0c7d413e0023c24d26373f3c62909433babb1da Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Mon, 8 Jan 2018 13:17:17 -0800 Subject: [PATCH] Make Adjoint/Transpose behave like typical constructors. --- base/linalg/adjtrans.jl | 104 +++++++++++++++++++--------------------- test/linalg/adjtrans.jl | 32 ++++++------- 2 files changed, 66 insertions(+), 70 deletions(-) diff --git a/base/linalg/adjtrans.jl b/base/linalg/adjtrans.jl index 2aa57766b47e26..d8a0f23ff81363 100644 --- a/base/linalg/adjtrans.jl +++ b/base/linalg/adjtrans.jl @@ -11,46 +11,42 @@ import Base: length, size, axes, IndexStyle, getindex, setindex!, parent, vec, c struct Adjoint{T,S} <: AbstractMatrix{T} parent::S function Adjoint{T,S}(A::S) where {T,S} - checkeltype(Adjoint, T, eltype(A)) + checkeltype_adjoint(T, eltype(A)) new(A) end end struct Transpose{T,S} <: AbstractMatrix{T} parent::S function Transpose{T,S}(A::S) where {T,S} - checkeltype(Transpose, T, eltype(A)) + checkeltype_transpose(T, eltype(A)) new(A) end end -function checkeltype(::Type{Transform}, ::Type{ResultEltype}, ::Type{ParentEltype}) where {Transform, ResultEltype, ParentEltype} - if ResultEltype !== transformtype(Transform, ParentEltype) - error(string("Element type mismatch. Tried to create an `$Transform{$ResultEltype}` ", - "from an object with eltype `$ParentEltype`, but the element type of the ", - "`$Transform` of an object with eltype `$ParentEltype` must be ", - "`$(transformtype(Transform, ParentEltype))`")) - end +function checkeltype_adjoint(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype} + ResultEltype === Base.promote_op(adjoint, ParentEltype) || error(string( + "Element type mismatch. Tried to create an `Adjoint{$ResultEltype}` ", + "from an object with eltype `$ParentEltype`, but the element type of ", + "the adjoint of an object with eltype `$ParentEltype` must be ", + "`$(Base.promote_op(adjoint, ParentEltype))`.")) return nothing end -function transformtype(::Type{O}, ::Type{S}) where {O,S} - # similar to promote_op(::Any, ::Type) - @_inline_meta - T = _return_type(O, Tuple{_default_type(S)}) - _isleaftype(S) && return _isleaftype(T) ? T : Any - return typejoin(S, T) +function checkeltype_transpose(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype} + ResultEltype === Base.promote_op(transpose, ParentEltype) || error(string( + "Element type mismatch. Tried to create a `Transpose{$ResultEltype}` ", + "from an object with eltype `$ParentEltype`, but the element type of ", + "the transpose of an object with eltype `$ParentEltype` must be ", + "`$(Base.promote_op(transpose, ParentEltype))`.")) + return nothing end # basic outer constructors -Adjoint(A) = Adjoint{transformtype(Adjoint,eltype(A)),typeof(A)}(A) -Transpose(A) = Transpose{transformtype(Transpose,eltype(A)),typeof(A)}(A) - -# numbers are the end of the line -Adjoint(x::Number) = adjoint(x) -Transpose(x::Number) = transpose(x) +Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A) +Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A) -# unwrapping constructors -Adjoint(A::Adjoint) = A.parent -Transpose(A::Transpose) = A.parent +# no-op constructors for already-wrapped objects +Adjoint(A::Adjoint) = A +Transpose(A::Transpose) = A # wrapping lowercase quasi-constructors adjoint(A::AbstractVecOrMat) = Adjoint(A) @@ -80,6 +76,7 @@ julia> transpose(A) ``` """ transpose(A::AbstractVecOrMat) = Transpose(A) + # unwrapping lowercase quasi-constructors adjoint(A::Adjoint) = A.parent transpose(A::Transpose) = A.parent @@ -95,10 +92,8 @@ const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector} const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix} # for internal use below -wrappertype(A::Adjoint) = Adjoint -wrappertype(A::Transpose) = Transpose -wrappertype(::Type{<:Adjoint}) = Adjoint -wrappertype(::Type{<:Transpose}) = Transpose +wrapperop(A::Adjoint) = adjoint +wrapperop(A::Transpose) = transpose # AbstractArray interface, basic definitions length(A::AdjOrTrans) = length(A.parent) @@ -108,13 +103,13 @@ axes(v::AdjOrTransAbsVec) = (Base.OneTo(1), axes(v.parent)...) axes(A::AdjOrTransAbsMat) = reverse(axes(A.parent)) IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear() IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian() -@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrappertype(v)(v.parent[i]) -@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrappertype(A)(A.parent[j, i]) -@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrappertype(v)(x), i); v) -@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrappertype(A)(x), j, i); A) +@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrapperop(v)(v.parent[i]) +@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrapperop(A)(A.parent[j, i]) +@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i); v) +@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A) # AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate -@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrappertype(v)(v.parent[is]) -@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrappertype(v)(v.parent[:]) +@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is]) +@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:]) # conversion of underlying storage convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent)) @@ -122,8 +117,8 @@ convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(conve # for vectors, the semantics of the wrapped and unwrapped types differ # so attempt to maintain both the parent and wrapper type insofar as possible -similar(A::AdjOrTransAbsVec) = wrappertype(A)(similar(A.parent)) -similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrappertype(A)(similar(A.parent, transformtype(wrappertype(A), T))) +similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent)) +similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrapperop(A)(similar(A.parent, Base.promote_op(wrapperop(A), T))) # for matrices, the semantics of the wrapped and unwrapped types are generally the same # and as you are allocating with similar anyway, you might as well get something unwrapped similar(A::AdjOrTrans) = similar(A.parent, eltype(A), size(A)) @@ -142,15 +137,16 @@ isless(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = isless(parent(A), parent(B)) # to retain the associated semantics post-concatenation hcat(avs::Union{Number,AdjointAbsVec}...) = _adjoint_hcat(avs...) hcat(tvs::Union{Number,TransposeAbsVec}...) = _transpose_hcat(tvs...) -_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = Adjoint(vcat(map(Adjoint, avs)...)) -_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = Transpose(vcat(map(Transpose, tvs)...)) -typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = Adjoint(typed_vcat(T, map(Adjoint, avs)...)) -typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = Transpose(typed_vcat(T, map(Transpose, tvs)...)) +_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = adjoint(vcat(map(adjoint, avs)...)) +_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = transpose(vcat(map(transpose, tvs)...)) +typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = adjoint(typed_vcat(T, map(adjoint, avs)...)) +typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = transpose(typed_vcat(T, map(transpose, tvs)...)) # otherwise-redundant definitions necessary to prevent hitting the concat methods in sparse/sparsevector.jl hcat(avs::Adjoint{<:Any,<:Vector}...) = _adjoint_hcat(avs...) hcat(tvs::Transpose{<:Any,<:Vector}...) = _transpose_hcat(tvs...) hcat(avs::Adjoint{T,Vector{T}}...) where {T} = _adjoint_hcat(avs...) hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...) +# TODO unify and allow mixed combinations ### higher order functions @@ -158,14 +154,14 @@ hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...) # to retain the associated semantics post-map/broadcast # # note that the caller's operation f operates in the domain of the wrapped vectors' entries. -# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries. -map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...)) -map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...)) +# hence the adjoint->f->adjoint shenanigans applied to the parent vectors' entries. +map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...)), parent.(avs)...)) +map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...)) quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below -broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparenta.(avs)...)) -broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparentt.(tvs)...)) - +broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...)) +broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) +# TODO unify and allow mixed combinations ### linear algebra @@ -186,11 +182,11 @@ end *(u::TransposeAbsVec, v::TransposeAbsVec) = throw(MethodError(*, (u, v))) # Adjoint/Transpose-vector * matrix -*(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) * u.parent) -*(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) * u.parent) +*(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) * u.parent) +*(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) * u.parent) # Adjoint/Transpose-vector * Adjoint/Transpose-matrix -*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Adjoint(A.parent * u.parent) -*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Transpose(A.parent * u.parent) +*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(A.parent * u.parent) +*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = transpose(A.parent * u.parent) ## pseudoinversion @@ -203,10 +199,10 @@ pinv(v::TransposeAbsVec, tol::Real = 0) = pinv(conj(v.parent)).parent ## right-division \ -/(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) \ u.parent) -/(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) \ u.parent) -/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Adjoint(conj(A.parent) \ u.parent) # technically should be Adjoint(copy(Adjoint(copy(A))) \ u.parent) -/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Transpose(conj(A.parent) \ u.parent) # technically should be Transpose(copy(Transpose(copy(A))) \ u.parent) +/(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) \ u.parent) +/(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) \ u.parent) +/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = adjoint(conj(A.parent) \ u.parent) # technically should be adjoint(copy(adjoint(copy(A))) \ u.parent) +/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = transpose(conj(A.parent) \ u.parent) # technically should be transpose(copy(transpose(copy(A))) \ u.parent) # dismabiguation methods *(A::AdjointAbsVec, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B) diff --git a/test/linalg/adjtrans.jl b/test/linalg/adjtrans.jl index c84a551a84c332..d6a068bf41d3ab 100644 --- a/test/linalg/adjtrans.jl +++ b/test/linalg/adjtrans.jl @@ -62,23 +62,12 @@ end # the tests for the inner constructors exercise abstract scalar and concrete array eltype, forgoing here end -@testset "Adjoint and Transpose of Numbers" begin - @test Adjoint(1) == 1 - @test Adjoint(1.0) == 1.0 - @test Adjoint(1im) == -1im - @test Adjoint(1.0im) == -1.0im - @test Transpose(1) == 1 - @test Transpose(1.0) == 1.0 - @test Transpose(1im) == 1im - @test Transpose(1.0im) == 1.0im -end - -@testset "Adjoint and Transpose unwrapping" begin +@testset "Adjoint and Transpose no-op on already-wrapped objects" begin intvec, intmat = [1, 2], [1 2; 3 4] - @test Adjoint(Adjoint(intvec)) === intvec - @test Adjoint(Adjoint(intmat)) === intmat - @test Transpose(Transpose(intvec)) === intvec - @test Transpose(Transpose(intmat)) === intmat + @test (A = Adjoint(intvec); Adjoint(A) === A) + @test (A = Adjoint(intmat); Adjoint(A) === A) + @test (A = Transpose(intvec); Transpose(A) === A) + @test (A = Transpose(intmat); Transpose(A) === A) end @testset "Adjoint and Transpose basic AbstractArray functionality" begin @@ -441,6 +430,17 @@ end end end +@testset "adjoint and transpose of Numbers" begin + @test adjoint(1) == 1 + @test adjoint(1.0) == 1.0 + @test adjoint(1im) == -1im + @test adjoint(1.0im) == -1.0im + @test transpose(1) == 1 + @test transpose(1.0) == 1.0 + @test transpose(1im) == 1im + @test transpose(1.0im) == 1.0im +end + @testset "adjoint!(a, b) return a" begin a = fill(1.0+im, 5) b = fill(1.0+im, 1, 5)