-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OffsetArray support for cat/vcat/hcat #37629
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1442,13 +1442,13 @@ function _typed_vcat(::Type{T}, V::AbstractVecOrTuple{AbstractVector}) where T | |||||
for Vk in V | ||||||
n += Int(length(Vk))::Int | ||||||
end | ||||||
a = similar(V[1], T, n) | ||||||
pos = 1 | ||||||
for k=1:Int(length(V))::Int | ||||||
a = similar(first(V), T, n) | ||||||
pos = first(axes(a, 1)) | ||||||
for k = eachindex(V) | ||||||
Vk = V[k] | ||||||
p1 = pos + Int(length(Vk))::Int - 1 | ||||||
a[pos:p1] = Vk | ||||||
pos = p1+1 | ||||||
n = length(Vk) | ||||||
copyto!(a, pos, Vk, first(axes(Vk, 1)), n) | ||||||
pos += n | ||||||
end | ||||||
a | ||||||
end | ||||||
|
@@ -1459,11 +1459,10 @@ hcat(A::AbstractVecOrMat...) = typed_hcat(promote_eltype(A...), A...) | |||||
hcat(A::AbstractVecOrMat{T}...) where {T} = typed_hcat(T, A...) | ||||||
|
||||||
function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T | ||||||
nargs = length(A) | ||||||
nrows = size(A[1], 1) | ||||||
nrows = size(first(A), 1) | ||||||
ncols = 0 | ||||||
dense = true | ||||||
for j = 1:nargs | ||||||
for j = eachindex(A) | ||||||
Aj = A[j] | ||||||
if size(Aj, 1) != nrows | ||||||
throw(ArgumentError("number of rows of each array must match (got $(map(x->size(x,1), A)))")) | ||||||
|
@@ -1472,17 +1471,17 @@ function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T | |||||
nd = ndims(Aj) | ||||||
ncols += (nd==2 ? size(Aj,2) : 1) | ||||||
end | ||||||
B = similar(A[1], T, nrows, ncols) | ||||||
pos = 1 | ||||||
B = similar(first(A), T, nrows, ncols) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How difficult would it be to adapt this to have I guess that's only the case of a vector of vectors; when some elements are matrices you'd have to give up on the 2nd axis. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand your comment correctly. As explained in JuliaArrays/OffsetArrays.jl#63 (comment), a meaningful offset propagation is only possible for very restricted conditions. This restriction doesn't make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I didn't see the issue, let me look there first. |
||||||
pos = first(axes(B, 1)) | ||||||
if dense | ||||||
for k=1:nargs | ||||||
for k=eachindex(A) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spaces around |
||||||
Ak = A[k] | ||||||
n = length(Ak) | ||||||
copyto!(B, pos, Ak, 1, n) | ||||||
pos += n | ||||||
end | ||||||
else | ||||||
for k=1:nargs | ||||||
for k=eachindex(A) | ||||||
Ak = A[k] | ||||||
p1 = pos+(isa(Ak,AbstractMatrix) ? size(Ak, 2) : 1)-1 | ||||||
B[:, pos:p1] = Ak | ||||||
|
@@ -1496,17 +1495,16 @@ vcat(A::AbstractVecOrMat...) = typed_vcat(promote_eltype(A...), A...) | |||||
vcat(A::AbstractVecOrMat{T}...) where {T} = typed_vcat(T, A...) | ||||||
|
||||||
function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T | ||||||
nargs = length(A) | ||||||
nrows = sum(a->size(a, 1), A)::Int | ||||||
ncols = size(A[1], 2) | ||||||
for j = 2:nargs | ||||||
ncols = size(first(A), 2) | ||||||
for j = first(axes(A))[2:end] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if size(A[j], 2) != ncols | ||||||
throw(ArgumentError("number of columns of each array must match (got $(map(x->size(x,2), A)))")) | ||||||
end | ||||||
end | ||||||
B = similar(A[1], T, nrows, ncols) | ||||||
pos = 1 | ||||||
for k=1:nargs | ||||||
B = similar(first(A), T, nrows, ncols) | ||||||
pos = first(axes(B, 1)) | ||||||
for k=eachindex(A) | ||||||
Ak = A[k] | ||||||
p1 = pos+size(Ak,1)::Int-1 | ||||||
B[pos:p1, :] = Ak | ||||||
|
@@ -1589,7 +1587,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) | |||||
@inline function _cat_t(dims, ::Type{T}, X...) where {T} | ||||||
catdims = dims2cat(dims) | ||||||
shape = cat_shape(catdims, map(cat_size, X)::Tuple{Vararg{Union{Int,Dims}}})::Dims | ||||||
A = cat_similar(X[1], T, shape) | ||||||
A = cat_similar(first(X), T, shape) | ||||||
if count(!iszero, catdims)::Int > 1 | ||||||
fill!(A, zero(T)) | ||||||
end | ||||||
|
@@ -1604,7 +1602,7 @@ function __cat(A, shape::NTuple{M,Int}, catdims, X...) where M | |||||
for x in X | ||||||
for i = 1:N | ||||||
if concat[i] | ||||||
inds[i] = offsets[i] .+ cat_indices(x, i) | ||||||
inds[i] = offsets[i] .+ parent(cat_indices(x, i)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's this about? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This strips the offsets and uses linear indexing when computing indices. This could be fragile and a little bit tricky, though. Phehaps there's a better way to make things work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would generally try to stay away from assumptions about how |
||||||
offsets[i] += cat_size(x, i) | ||||||
else | ||||||
inds[i] = 1:shape[i] | ||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -22,7 +22,7 @@ using Distributed: splitrange | |||
@test splitrange(-1, 1, 4) == Array{UnitRange{Int64},1}([-1:-1,0:0,1:1]) | ||||
|
||||
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") | ||||
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl")) | ||||
isdefined(Main, :OffsetArrays) || @eval Main @everywhere include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl")) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the reason for this change? Could this cause problems in tests running in other workers? What if it defines methods that are supposed to fail in those tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm unsure about this, though. Without this
Adding this helps remove this error. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe Line 4 in e378767
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's already I think it's because It would be better to get #37643 first and I'll drop the OffsetArrays upgrade commit here. |
||||
using .Main.OffsetArrays | ||||
|
||||
oa = OffsetArray([123, -345], (-2,)) | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because this often gets called with poor inference. Try
@code_warntype Base._typed_vcat(Float64, AbstractVector[[1,2,3], [4.0]])
.Not putting in these annotation re-opens us to all sorts of code invalidation, see https://julialang.org/blog/2020/08/invalidations/.
This isn't a general pattern you need to start adopting in all your Julia code, but this is a very low-level function called by lots of things (all of which would be invalidated if this is invalidated) and from its signature alone you might guess it's a bit of an inference nightmare.
I'm not adding comments like these elsewhere in this function, but you'll probably need to add other annotations to the initialization of
pos
above. Ideally, it would be great to check with SnoopCompile and loading various packages likeMakie
to see if you've greatly increased the number of invalidations.