Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReinterpretArray and IndexStyle #27909

Merged
merged 1 commit into from
Jul 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
end
end

IndexStyle(a::ReinterpretArray) = IndexStyle(a.parent)

parent(a::ReinterpretArray) = a.parent
dataids(a::ReinterpretArray) = dataids(a.parent)

Expand All @@ -51,11 +53,25 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
_getindex_ra(a, inds[1], tail(inds))
end

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, i::Int) where {T,N,S}
if isa(IndexStyle(a), IndexLinear)
return _getindex_ra(a, i, ())
end
# Convert to full indices here, to avoid needing multiple conversions in
# the loop in _getindex_ra
inds = _to_subscript_indices(a, i)
_getindex_ra(a, inds[1], tail(inds))
end

@inline @propagate_inbounds function _getindex_ra(a::ReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return reinterpret(T, a.parent[inds...])
return reinterpret(T, a.parent[i1, tailinds...])
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
GC.@preserve t s begin
Expand All @@ -67,7 +83,7 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
Expand All @@ -81,16 +97,29 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
end
end


@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
_setindex_ra!(a, v, inds[1], tail(inds))
end

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, i::Int) where {T,N,S}
if isa(IndexStyle(a), IndexLinear)
return _setindex_ra!(a, v, i, ())
end
inds = _to_subscript_indices(a, i)
_setindex_ra!(a, v, inds[1], tail(inds))
end

@inline @propagate_inbounds function _setindex_ra!(a::ReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), inds...)
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
Expand All @@ -101,13 +130,13 @@ end
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tail(inds)...]
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
Expand All @@ -118,19 +147,19 @@ end
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
a.parent[ind_start + i, tailinds...] = s[]
end
end
end
Expand Down
19 changes: 19 additions & 0 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,22 @@ let A = collect(reshape(1:20, 5, 4))
@test view(R, :, :) isa StridedArray
@test reshape(R, :) isa StridedArray
end

# IndexStyle
let a = fill(1.0, 5, 3)
r = reinterpret(Int64, a)
@test @inferred(IndexStyle(r)) == IndexLinear()
fill!(r, 2)
@test all(a .=== reinterpret(Float64, [Int64(2)])[1])
@test all(r .=== Int64(2))
r = reinterpret(Int32, a)
@test @inferred(IndexStyle(r)) == IndexLinear()
fill!(r, 3)
@test all(a .=== reinterpret(Float64, [(Int32(3), Int32(3))])[1])
@test all(r .=== Int32(3))
r = reinterpret(Int64, view(a, 1:2:5, :))
@test @inferred(IndexStyle(r)) == IndexCartesian()
fill!(r, 4)
@test all(a[1:2:5,:] .=== reinterpret(Float64, [Int64(4)])[1])
@test all(r .=== Int64(4))
end