Skip to content

Commit

Permalink
ngenerate/nsplat: BitArray setindex! part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Nov 21, 2014
1 parent a030d08 commit 3477f04
Showing 1 changed file with 71 additions and 58 deletions.
129 changes: 71 additions & 58 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,28 +560,35 @@ end
# bounds check and is defined in bitarray.jl)
# (code is duplicated for safe and unsafe versions for performance reasons)

@ngenerate N typeof(B) function unsafe_setindex!(B::BitArray, x::Bool, I_0::Int, I::NTuple{N,Int}...)
stride = 1
index = I_0
@nexprs N d->begin
stride *= size(B,d)
index += (I_d - 1) * stride
stagedfunction unsafe_setindex!(B::BitArray, x::Bool, I_0::Int, I::Int...)
N = length(I)
quote
stride = 1
index = I_0
@nexprs $N d->begin
stride *= size(B,d)
index += (I[d] - 1) * stride
end
unsafe_setindex!(B, x, index)
return B
end
unsafe_setindex!(B, x, index)
return B
end

@ngenerate N typeof(B) function setindex!(B::BitArray, x::Bool, I_0::Int, I::NTuple{N,Int}...)
stride = 1
index = I_0
@nexprs N d->begin
l = size(B,d)
stride *= l
1 <= I_{d-1} <= l || throw(BoundsError())
index += (I_d - 1) * stride
stagedfunction setindex!(B::BitArray, x::Bool, I_0::Int, I::Int...)
N = length(I)
quote
stride = 1
index = I_0
@nexprs $N d->(I_d = I[d])
@nexprs $N d->begin
l = size(B,d)
stride *= l
1 <= I_{d-1} <= l || throw(BoundsError())
index += (I_d - 1) * stride
end
B[index] = x
return B
end
B[index] = x
return B
end

# contiguous multidimensional indexing: if the first dimension is a range,
Expand All @@ -603,59 +610,65 @@ function unsafe_setindex!(B::BitArray, x::Bool, I0::UnitRange{Int})
return B
end

@ngenerate N typeof(B) function unsafe_setindex!(B::BitArray, X::BitArray, I0::UnitRange{Int}, I::NTuple{N,Union(Int,UnitRange{Int})}...)
length(X) == 0 && return B
f0 = first(I0)
l0 = length(I0)
stagedfunction unsafe_setindex!(B::BitArray, X::BitArray, I0::UnitRange{Int}, I::Union(Int,UnitRange{Int})...)
N = length(I)
quote
length(X) == 0 && return B
f0 = first(I0)
l0 = length(I0)

gap_lst_1 = 0
@nexprs N d->(gap_lst_{d+1} = length(I_d))
stride = 1
ind = f0
@nexprs N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I_d) - 1)
gap_lst_{d+1} *= stride
end
gap_lst_1 = 0
@nexprs $N d->(gap_lst_{d+1} = length(I[d]))
stride = 1
ind = f0
@nexprs $N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I[d]) - 1)
gap_lst_{d+1} *= stride
end

refind = 1
@nloops(N, i, d->I_d,
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
refind = 1
@nloops($N, i, d->I[d],
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
begin # BODY
copy_chunks!(B.chunks, ind, X.chunks, refind, l0)
refind += l0
end)
end)

return B
return B
end
end

@ngenerate N typeof(B) function unsafe_setindex!(B::BitArray, x::Bool, I0::UnitRange{Int}, I::NTuple{N,Union(Int,UnitRange{Int})}...)
f0 = first(I0)
l0 = length(I0)
l0 == 0 && return B
@nexprs N d->(length(I_d) == 0 && return B)
stagedfunction unsafe_setindex!(B::BitArray, x::Bool, I0::UnitRange{Int}, I::Union(Int,UnitRange{Int})...)
N = length(I)
quote
f0 = first(I0)
l0 = length(I0)
l0 == 0 && return B
@nexprs $N d->(length(I[d]) == 0 && return B)

gap_lst_1 = 0
@nexprs N d->(gap_lst_{d+1} = length(I_d))
stride = 1
ind = f0
@nexprs N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I_d) - 1)
gap_lst_{d+1} *= stride
end
gap_lst_1 = 0
@nexprs $N d->(gap_lst_{d+1} = length(I[d]))
stride = 1
ind = f0
@nexprs $N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I[d]) - 1)
gap_lst_{d+1} *= stride
end

@nloops(N, i, d->I_d,
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
@nloops($N, i, d->I[d],
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
begin # BODY
fill_chunks!(B.chunks, x, ind, l0)
end)
end)

return B
return B
end
end


Expand Down

0 comments on commit 3477f04

Please sign in to comment.