Skip to content

Commit bf99cde

Browse files
committed
improve cat inferrability
Make `cat` inferrable even if its arguments are not fully constant: ```julia julia> r = rand(Float32, 56, 56, 64, 1); julia> f(r) = cat(r, r, dims=(3,)) f (generic function with 1 method) julia> @inferred f(r); julia> last(@code_typed f(r)) Array{Float32, 4} ``` After descending into its call graph, I found that constant propagation is prohibited at `cat_t(::Type{T}, X...; dims)` due to the method instance heuristic, i.e. its body is considered to be too complex for successful inlining although it's explicitly annotated as `@inline`. But for this case, the constant propagation is greatly helpful both for abstract interpretation and optimization since it can improve the return type inference. Since it is not an easy task to improve the method instance heuristic, which is our primary logic for constant propagation, this commit does a quick fix by helping inference with the `@constprop` annotation. There is another issue that currently there is no good way to properly apply `@constprop`/`@inline` effects to a keyword function (as a note, this is a general issue of macro annotations on a method definition). So this commit also changes some internal helper functions of `cat` so that now they are not keyword ones: the changes are also necessary for the `@inline` annotation on `cat_t` to be effective to trick the method instance heuristic.
1 parent ab11173 commit bf99cde

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

base/abstractarray.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,10 +1730,9 @@ function dims2cat(dims)
17301730
ntuple(in(dims), maximum(dims))
17311731
end
17321732

1733-
_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
1733+
_cat(dims, X...) = cat_t(dims, promote_eltypeof(X...), X...)
17341734

1735-
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
1736-
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
1735+
@inline function cat_t(dims, ::Type{T}, X...) where {T}
17371736
catdims = dims2cat(dims)
17381737
shape = cat_size_shape(catdims, X...)
17391738
A = cat_similar(X[1], T, shape)
@@ -1742,6 +1741,8 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
17421741
end
17431742
return __cat(A, shape, catdims, X...)
17441743
end
1744+
# just for compat after https://github.com/JuliaLang/julia/pull/45028
1745+
@inline cat_t(::Type{T}, X...; dims) where {T} = cat_t(dims, T, X...)
17451746

17461747
# Why isn't this called `__cat!`?
17471748
__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
@@ -1880,8 +1881,8 @@ julia> reduce(hcat, vs)
18801881
"""
18811882
hcat(X...) = cat(X...; dims=Val(2))
18821883

1883-
typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1))
1884-
typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2))
1884+
typed_vcat(::Type{T}, X...) where T = cat_t(Val(1), T, X...)
1885+
typed_hcat(::Type{T}, X...) where T = cat_t(Val(2), T, X...)
18851886

18861887
"""
18871888
cat(A...; dims)
@@ -1917,7 +1918,8 @@ julia> cat(true, trues(2,2), trues(4)', dims=(1,2))
19171918
```
19181919
"""
19191920
@inline cat(A...; dims) = _cat(dims, A...)
1920-
_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims)
1921+
# `@constprop :aggressive` allows `catdims` to be propagated as constant improving return type inference
1922+
@constprop :aggressive _cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(catdims, T, A...)
19211923

19221924
# The specializations for 1 and 2 inputs are important
19231925
# especially when running with --inline=no, see #11158
@@ -1928,12 +1930,12 @@ hcat(A::AbstractArray) = cat(A; dims=Val(2))
19281930
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
19291931
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))
19301932

1931-
typed_vcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(1))
1932-
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(1))
1933-
typed_vcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(1))
1934-
typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2))
1935-
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2))
1936-
typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2))
1933+
typed_vcat(T::Type, A::AbstractArray) = cat_t(Val(1), T, A)
1934+
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val(1), T, A, B)
1935+
typed_vcat(T::Type, A::AbstractArray...) = cat_t(Val(1), T, A...)
1936+
typed_hcat(T::Type, A::AbstractArray) = cat_t(Val(2), T, A)
1937+
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val(2), T, A, B)
1938+
typed_hcat(T::Type, A::AbstractArray...) = cat_t(Val(2), T, A...)
19371939

19381940
# 2d horizontal and vertical concatenation
19391941

stdlib/LinearAlgebra/src/special.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,14 +414,14 @@ const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpo
414414

415415
promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix
416416

417-
Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(promote_eltype(xs...), xs...; dims=dims)
417+
Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(dims, promote_eltype(xs...), xs...)
418418
vcat(A::Vector...) = Base.typed_vcat(promote_eltype(A...), A...)
419419
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
420420
hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...)
421421
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
422422
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
423423
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
424-
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims)
424+
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(dims, T, xs...)
425425
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
426426
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
427427
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)

test/abstractarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,10 @@ function test_cat(::Type{TestAbstractArray})
733733
cat3v(As) = cat(As...; dims=Val(3))
734734
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
735735
@test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4)
736+
737+
r = rand(Float32, 56, 56, 64, 1);
738+
f(r) = cat(r, r, dims=(3,))
739+
@inferred f(r);
736740
end
737741

738742
function test_ind2sub(::Type{TestAbstractArray})

0 commit comments

Comments
 (0)