Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mat index bug #249

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export fastindices # Deprecated
include("lazyarray.jl")

include("axis.jl")
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, ViewAxis, FlatAxis
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, Shaped1DAxis, ViewAxis, FlatAxis

include("componentarray.jl")
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
Expand Down
22 changes: 18 additions & 4 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,22 @@ example)
"""
struct ShapedAxis{Shape} <: AbstractAxis{nothing} end
@inline ShapedAxis(Shape) = ShapedAxis{Shape}()
ShapedAxis(::Tuple{<:Int}) = FlatAxis()
# ShapedAxis(::Tuple{<:Int}) = FlatAxis()
Base.length(::ShapedAxis{Shape}) where{Shape} = prod(Shape)

struct Shaped1DAxis{Shape} <: AbstractAxis{nothing} end
ShapedAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Shaped1DAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Base.length(::Shaped1DAxis{Shape}) where {Shape} = only(Shape)

const Shape = ShapedAxis

unshape(ax) = ax
unshape(ax::ShapedAxis) = Axis(indexmap(ax))
unshape(ax::Shaped1DAxis) = Axis(indexmap(ax))

Base.size(::ShapedAxis{Shape}) where {Shape} = Shape
Base.size(::Shaped1DAxis{Shape}) where {Shape} = Shape



Expand Down Expand Up @@ -133,9 +141,9 @@ Axis(::Number) = NullAxis()
Axis(::NamedTuple{()}) = FlatAxis()
Axis(x) = FlatAxis()

const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap}


Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...))
Expand All @@ -149,6 +157,10 @@ reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
function reindex(ax::ViewAxis{OldInds,IdxMap,Ax}, offset) where {OldInds,IdxMap,Ax<:Shaped1DAxis}
NewInds = viewindex(ax) .+ offset
return ViewAxis(NewInds, Ax())
end

# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand All @@ -175,6 +187,7 @@ end

_maybe_view_axis(inds, ax::AbstractAxis) = ViewAxis(inds, ax)
_maybe_view_axis(inds, ::NullAxis) = inds[1]
_maybe_view_axis(inds, ax::Union{ShapedAxis,Shaped1DAxis}) = ViewAxis(inds, ax)

struct CombinedAxis{C,A} <: AbstractUnitRange{Int}
component_axis::C
Expand All @@ -188,6 +201,7 @@ _component_axis(ax) = FlatAxis()

_array_axis(ax::CombinedAxis) = ax.array_axis
_array_axis(ax) = ax
_array_axis(ax::Int) = Shaped1DAxis((ax,))

Base.first(ax::CombinedAxis) = first(_array_axis(ax))

Expand Down
3 changes: 2 additions & 1 deletion src/compat/static_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ end

_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
_maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, vals...) = x

@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}
Expand Down Expand Up @@ -32,4 +33,4 @@ macro static_unpack(expr)
push!(out.args, :($esc_name = static_getproperty($parent_var_name, $(Val(name)))))
end
return out
end
end
9 changes: 7 additions & 2 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple} =

# Entry from data array and AbstractAxis types dispatches to correct shapes and partitions
# then packs up axes into a tuple for inner constructor
ComponentArray(data, ::FlatAxis...) = data
# ComponentArray(data, ::FlatAxis...) = data
ComponentArray(data, ::Union{FlatAxis,Shaped1DAxis}...) = data
ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax)
ComponentArray(data, ax::NotPartitionedAxis...) = ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...)
function ComponentArray(data, ax::AbstractAxis...)
Expand Down Expand Up @@ -166,6 +167,10 @@ function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val)
)...)
return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs))
end
function make_idx(data, nt::NamedTuple{(), Tuple{}}, last_val)
out = last_index(last_val) .+ (1:length(nt))
return (data, ViewAxis(out, ShapedAxis((length(nt),))))
end
function make_idx(data, pair::Pair, last_val)
data, ax = make_idx(data, pair.second, last_val)
len = recursive_length(data)
Expand Down Expand Up @@ -232,7 +237,7 @@ end
# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
shapes = filter_by_type(ShapedAxis, axs...) .|> size
shapes = filter_by_type(Union{ShapedAxis,Shaped1DAxis}, axs...) .|> size
shapes = reduce((tup, s) -> (tup..., s...), shapes)
return reshape(data, shapes)
end
Expand Down
4 changes: 3 additions & 1 deletion src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ struct ComponentIndex{Idx, Ax<:AbstractAxis}
ax::Ax
end
ComponentIndex(idx) = ComponentIndex(idx, FlatAxis())
ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,)))
ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx)))
ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis())
ComponentIndex(vax::ViewAxis{Inds,IdxMap,Ax}) where {Inds,IdxMap,Ax} = ComponentIndex(Inds, vax.ax)

Expand Down Expand Up @@ -44,4 +46,4 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol)
end
new_ax = reindex(new_ax, -first(idx)+1)
return ComponentIndex(idx, new_ax)
end
end
2 changes: 2 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap,

Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} =
print(io, "ShapedAxis($Shape)")
Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} =
print(io, "Shaped1DAxis($Shape)")

Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} =
print(io, "ViewAxis($Inds, $(Ax()))")
Expand Down
168 changes: 159 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ using Unitful
using Functors
import TruncatedStacktraces # This is loaded just to trigger the extension package

# Convert abstract unit range to a ViewAxis with ShapeAxis.
r2v(r::AbstractUnitRange) = ViewAxis(r, ShapedAxis(size(r)))

## Test setup
c = (a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45])
nt = (a = 100, b = [4, 1.3], c = c)
nt2 = (a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6]))

ax = Axis(a = 1, b = 2:3, c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)))
ax_c = (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)
ax = Axis(a = 1, b = r2v(2:3), c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I guess I don't understand why we'd need to wrap this in a new type. In the example you gave in the issue, b is a vector element, not an nx1 matrix. I don't think we should introduce a new ShapedAxis1d type if all vector elements are behaving incorrectly. We should just fix the issue of adjoint/transposition that's broken directly. So I guess specifically, the ShapedAxis needs to be created from the FlatAxis during the adjoint operation, not beforehand. And for that, we can use the normal ShapedAxis without having to introduce a new type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie My bad, just saw this comment.
Hmm... not sure I follow, but fwiw I think everything is fine with the transpose—the issue comes up later on.
The problem with the test case in issue #249, aka

ni = 4
nj = 2
nk = 3

X = ComponentVector(a=0.0, b=zeros(Float64, nk), c=zeros(Float64, ni, nj))
Y = ComponentVector(d=zeros(Float64, ni, nj))
J = Y.*X'

appears to be due to the fact that the shape of a unit range (which is used for vector components like X[:b] in the example) isn't understood by ComponentArrays.maybe_reshape. The error I get:

julia> J[:d, :b]
ERROR: DimensionMismatch: new dimensions (4, 2) must be consistent with array size 24
Stacktrace:
 [1] (::Base.var"#throw_dmrsa#328")(dims::Tuple{Int64, Int64}, len::Int64)
   @ Base ./reshapedarray.jl:41
 [2] reshape
   @ ./reshapedarray.jl:45 [inlined]
 [3] maybe_reshape
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:237 [inlined]
 [4] ComponentArray
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:52 [inlined]
 [5] macro expansion
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:0 [inlined]
 [6] _getindex(::typeof(getindex), ::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Val{:d}, ::Val{:b})
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:119
 [7] getindex
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:103 [inlined]
 [8] getindex(::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Symbol, ::Symbol)
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:102
 [9] top-level scope
   @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

shell> 

In maybe_reshape:

# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
    shapes = filter_by_type(ShapedAxis, axs...) .|> size
    shapes = reduce((tup, s) -> (tup..., s...), shapes)
    return reshape(data, shapes)
end

during J[:d, :b], axs is

axs = (ShapedAxis((4, 2)), FlatAxis())

i.e. the unit range is converted to a FlatAxis, which is then ignored when defining shapes in maybe_reshape.

The axes of the transpose X' look fine to me:

julia> getaxes(X')
(FlatAxis(), Axis(a = 1, b = 2:4, c = ViewAxis(5:12, ShapedAxis((4, 2)))))

julia> 

The point of the Shaped1DAxis is to have a type that can be a part of the NotShapedOrPartitionedAxis Union defined in axis.jl.
This allows it to skip the problematic and unnecessary maybe_reshape for vector components.
If there was a way to dispatch on ShapedAxis{Shape} when Shape was a length-1 Tuple then I don't think we'd need Shaped1DAxis, but I don't know of a way to do that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR. :-)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR.

ax_c = (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))

a = Float64[100, 4, 1.3, 1, 1, 4.4, 0.4, 2, 1, 45]
sq_mat = collect(reshape(1:9, 3, 3))
Expand All @@ -39,6 +41,9 @@ caa = ComponentArray(a = ca, b = sq_mat)

_a, _b, _c = Val.((:a, :b, :c))

ca3 = ComponentArray(a=1, b=[2, 3, 4, 5], c=reshape(6:11, 3, 2))
cmat3 = ca3 .* ca3'
cmat3check = (1:11) .* (1:11)'

## Tests
@testset "Allocations and Inference" begin
Expand Down Expand Up @@ -287,6 +292,19 @@ end
# We had to revert this because there is no way to work around
# OffsetArrays' type piracy without introducing type piracy
# ourselves because `() isa Tuple{N, <:CombinedAxis} where {N}`
# @test reshape(a, axes(ca)...) isa Vector{Float64}

# Issue #248: Indexing ComponentMatrix with FlatAxis components
@test cmat3[:a, :a] == cmat3check[1, 1]
@test cmat3[:a, :b] == cmat3check[1, 2:5]
@test cmat3[:a, :c] == reshape(cmat3check[1, 6:11], 3, 2)
@test cmat3[:b, :a] == cmat3check[2:5, 1]
@test cmat3[:b, :b] == cmat3check[2:5, 2:5]
@test cmat3[:b, :c] == reshape(cmat3check[2:5, 6:11], 4, 3, 2)
@test cmat3[:c, :a] == reshape(cmat3check[6:11, 1], 3, 2)
@test cmat3[:c, :b] == reshape(cmat3check[6:11, 2:5], 3, 2, 4)
@test cmat3[:c, :c] == reshape(cmat3check[6:11, 6:11], 3, 2, 3, 2)

@test_broken reshape(a, axes(ca)...) isa Vector{Float64}

# Issue #265: Multi-symbol indexing with matrix components
Expand Down Expand Up @@ -333,7 +351,7 @@ end
temp = deepcopy(cmat)
@test all((temp[:c, :c][:a, :a] .= 0) .== 0)

A = ComponentArray(zeros(Int, 4, 4), Axis(x = 1:4), Axis(x = 1:4))
A = ComponentArray(zeros(Int, 4, 4), Axis(x = r2v(1:4)), Axis(x = r2v(1:4)))
A[1, :] .= 1
@test A[1, :] == ComponentVector(x = ones(Int, 4))
end
Expand All @@ -357,12 +375,17 @@ end
ca = ComponentArray(a = 1, b = 2, c = [3, 4], d = (a = [5, 6, 7], b = 8))
cmat = ca * ca'

cidx = reshape((1:(2*3)) .+ 2, 2, 3)
ca2 = ComponentArray(a = 1, b = 2, c = cidx, d = (a = [9, 10, 11], b = 12))

@testset "ComponentIndex" begin
ax = getaxes(ca)[1]
@test ax[:a] == ax[1] == ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, FlatAxis())
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = 1:3, b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = 2:3))
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))
ax2 = getaxes(ca2)[1]
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))
end

@testset "KeepIndex" begin
Expand All @@ -373,14 +396,14 @@ end

@test ca[KeepIndex(1:2)] == ComponentArray(a = 1, b = 2)
@test ca[KeepIndex(1:3)] == ComponentArray([1, 2, 3], Axis(a = 1, b = 2)) # Drops c axis
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = 2:3))
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3)))
@test ca[KeepIndex(3:end)] == ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8))

@test ca[KeepIndex(:)] == ca

@test cmat[KeepIndex(:a), KeepIndex(:b)] == ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1))
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = 1:2))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = 2:3), FlatAxis())
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2)))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2)))
@test cmat[KeepIndex(2), KeepIndex(3)] == ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis())
@test cmat[KeepIndex(2), 3] == ComponentArray(b = 2 * 3)
end
Expand Down Expand Up @@ -690,6 +713,133 @@ end
# Issue #193
# Make sure we aren't doing type piracy on `reshape`
@test ndims(dropdims(ones(1,1), dims=(1,2))) == 0

if VERSION >= v"1.9"
# `stack` was introduced in Julia 1.9
# Issue #254
x = ComponentVector(a=[1, 2])
y = ComponentVector(a=[3, 4])
xy = stack([x, y])
# The data in `xy` should be the same as what we'd get if we used plain Vectors:
@test getdata(xy) == stack(getdata.([x, y]))
# Check the axes.
xy_ax = getaxes(xy)
# Should have two axes since xy should be a ComponentMatrix.
@test length(xy_ax) == 2
# First axis should be the same as x.
@test xy_ax[1] == only(getaxes(x))
# Second axis should be a FlatAxis.
@test xy_ax[2] == FlatAxis()

# Does the dims argument to stack work?
# Using `dims=2` should be the same as the default value.
xy2 = stack([x, y]; dims=2)
@test xy2 == xy
# Using `dims=1` should stack things vertically.
xy3 = stack([x, y]; dims=1)
@test all(xy3[1, :a] .== xy[:a, 1])
@test all(xy3[2, :a] .== xy[:a, 2])

# But can we stack 2D arrays?
x = ComponentVector(a=[1, 2])
y = ComponentVector(b=[3, 4])
X = x .* y'
Y = x .* y' .+ 4
XY = stack([X, Y])
# The data in `XY` should be the same as what we'd get if we used plain Vectors:
@test getdata(XY) == stack(getdata.([X, Y]))
# Check the axes.
XY_ax = getaxes(XY)
# Should have three axes since XY should be a 3D ComponentArray.
@test length(XY_ax) == 3
# First two axes should be the same as XY.
@test XY_ax[1] == getaxes(XY)[1]
@test XY_ax[2] == getaxes(XY)[2]
# Third should be a FlatAxis.
@test XY_ax[3] == FlatAxis()
# Should test indexing too.
@test all(XY[:a, :b, 1] .== X)
@test all(XY[:a, :b, 2] .== Y)

# Make sure the dims argument works.
# Using `dims=3` should be the same as the default value.
XY_d3 = stack([X, Y]; dims=3)
@test XY_d3 == XY
# Using `dims=2` stacks along the second axis.
XY_d2 = stack([X, Y]; dims=2)
@test all(XY_d2[:a, 1, :b] .== XY[:a, :b, 1])
@test all(XY_d2[:a, 2, :b] .== XY[:a, :b, 2])
# Using `dims=1` stacks along the first axis.
XY_d1 = stack([X, Y]; dims=1)
@test all(XY_d1[1, :a, :b] .== XY[:a, :b, 1])
@test all(XY_d1[2, :a, :b] .== XY[:a, :b, 2])

# Issue #254, tuple of arrays:
x = ComponentVector(a=[1, 2])
y = ComponentVector(b=[3, 4])
Xstack1 = stack((x, y, x); dims=1)
Xstack1_noca = stack((getdata(x), getdata(y), getdata(x)); dims=1)
@test all(Xstack1 .== Xstack1_noca)
@test all(Xstack1[1, :a] .== Xstack1_noca[1, :])
@test all(Xstack1[2, :a] .== Xstack1_noca[2, :])

# Issue #254, Array of tuples.
Xstack2 = stack(ComponentArray(a=(1,2,3), b=(4,5,6)))
Xstack2_noca = stack([(1,2,3), (4,5,6)])
@test all(Xstack2 .== Xstack2_noca)
@test all(Xstack2[:, :a] .== Xstack2_noca[:, 1])
@test all(Xstack2[:, :b] .== Xstack2_noca[:, 2])

Xstack2_d1 = stack(ComponentArray(a=(1,2,3), b=(4,5,6)); dims=1)
Xstack2_noca_d1 = stack([(1,2,3), (4,5,6)]; dims=1)
@test all(Xstack2_d1 .== Xstack2_noca_d1)
@test all(Xstack2_d1[:a, :] .== Xstack2_noca_d1[1, :])
@test all(Xstack2_d1[:b, :] .== Xstack2_noca_d1[2, :])

# Issue #254, generator of arrays.
Xstack3 = stack(ComponentArray(z=[x,x]) for x in 1:4)
Xstack3_noca = stack([x, x] for x in 1:4)
# That should give me
# [1 2 3 4;
# 1 2 3 4]
@test all(Xstack3 .== Xstack3_noca)
@test all(Xstack3[:z, 1] .== Xstack3_noca[:, 1])
@test all(Xstack3[:z, 2] .== Xstack3_noca[:, 2])
@test all(Xstack3[:z, 3] .== Xstack3_noca[:, 3])
@test all(Xstack3[:z, 4] .== Xstack3_noca[:, 4])

Xstack3_d1 = stack(ComponentArray(z=[x,x]) for x in 1:4; dims=1)
Xstack3_noca_d1 = stack([x, x] for x in 1:4; dims=1)
# That should give me
# [1 1;
# 2 2;
# 3 3;
# 4 4;]
@test all(Xstack3_d1 .== Xstack3_noca_d1)
@test all(Xstack3_d1[1, :z] .== Xstack3_noca_d1[1, :])
@test all(Xstack3_d1[2, :z] .== Xstack3_noca_d1[2, :])
@test all(Xstack3_d1[3, :z] .== Xstack3_noca_d1[3, :])
@test all(Xstack3_d1[4, :z] .== Xstack3_noca_d1[4, :])

# Issue #254, map then stack.
Xstack4_d1 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=1) # map then stack
Xstack4_noca_d1 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=1) # map then stack
@test all(Xstack4_d1 .== Xstack4_noca_d1)
@test all(Xstack4_d1[:, :a] .== Xstack4_noca_d1[:, 1])
@test all(Xstack4_d1[:, :b] .== Xstack4_noca_d1[:, 2:3])

Xstack4_d2 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=2) # map then stack
Xstack4_noca_d2 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=2) # map then stack
@test all(Xstack4_d2 .== Xstack4_noca_d2)
@test all(Xstack4_d2[:a, :] .== Xstack4_noca_d2[1, :])
@test all(Xstack4_d2[:b, :] .== Xstack4_noca_d2[2:3, :])

Xstack4_dcolon = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=:) # map then stack
Xstack4_noca_dcolon = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=:) # map then stack
@test all(Xstack4_dcolon .== Xstack4_noca_dcolon)
@test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :])
@test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :])
end
end

@testset "axpy! / axpby!" begin
Expand Down
Loading