Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions base/slicearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,36 @@ unitaxis(::AbstractArray) = Base.OneTo(1)

function Slices(A::P, slicemap::SM, ax::AX) where {P,SM,AX}
N = length(ax)
argT = map((a,l) -> l === (:) ? Colon : eltype(a), axes(A), slicemap)
parent_axes = ntuple(d -> axes(A, d), length(slicemap))
argT = map((a,l) -> l === (:) ? Colon : eltype(a), parent_axes, slicemap)
S = Base.promote_op(view, P, argT...)
Slices{P,SM,AX,S,N}(A, slicemap, ax)
end

_slice_check_dims(N) = nothing
function _slice_check_dims(N, dim, dims...)
1 <= dim <= N || throw(DimensionMismatch("Invalid dimension $dim"))
1 <= dim || throw(DimensionMismatch("Invalid dimension $dim"))
dim in dims && throw(DimensionMismatch("Dimensions $dims are not unique"))
_slice_check_dims(N,dims...)
end

@constprop :aggressive function _eachslice(A::AbstractArray{T,N}, dims::NTuple{M,Integer}, drop::Bool) where {T,N,M}
_slice_check_dims(N,dims...)
N_ = foldl(max, dims; init=N)

if drop
# if N = 4, dims = (3,1) then
# axes = (axes(A,3), axes(A,1))
# slicemap = (2, :, 1, :)
ax = map(dim -> axes(A,dim), dims)
slicemap = ntuple(dim -> something(findfirst(isequal(dim), dims), (:)), N)
slicemap = ntuple(dim -> something(findfirst(isequal(dim), dims), (:)), N_)
return Slices(A, slicemap, ax)
else
# if N = 4, dims = (3,1) then
# axes = (axes(A,1), OneTo(1), axes(A,3), OneTo(1))
# slicemap = (1, :, 3, :)
ax = ntuple(dim -> dim in dims ? axes(A,dim) : unitaxis(A), N)
slicemap = ntuple(dim -> dim in dims ? dim : (:), N)
ax = ntuple(dim -> dim in dims ? axes(A,dim) : unitaxis(A), N_)
slicemap = ntuple(dim -> dim in dims ? dim : (:), N_)
return Slices(A, slicemap, ax)
end
end
Expand Down
26 changes: 26 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2467,6 +2467,32 @@ end
@test_throws BoundsError A[2,3] = [4,5]
@test_throws BoundsError A[2,3] .= [4,5]
end

@testset "trailing dimensions" begin
v = collect(1:3)

S2 = eachslice(v; dims = 2, drop=true)
@test S2 isa AbstractSlices{<:AbstractVector, 1}
@test size(S2) == (1,)
@test S2[1] == v

S2K = eachslice(v; dims = 2, drop=false)
@test S2K isa AbstractSlices{<:AbstractVector, 2}
@test size(S2K) == (1,1)
@test S2K[1,1] == v

M = reshape(1:6, 2, 3)

S13 = eachslice(M; dims = (1,3))
@test size(S13) == (2,1)
@test S13[2,1] == M[2,:,1]

S13K = eachslice(M; dims = (1,3), drop=false)
@test size(S13K) == (2,1,1)
@test S13K[1,1,1] == M[1,:]
@test S13K[2,1,1] == M[2,:]
end

end

###
Expand Down