Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 67 additions & 70 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@ Gives a reinterpreted view (of element type T) of the underlying array (of eleme
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
the first dimension.
"""
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N},IsScalar} <: AbstractArray{T, N}
parent::A
readable::Bool
writable::Bool
global reinterpret
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
function throwbits(S::Type, T::Type, U::Type)
@_noinline_meta
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
end
function throwsize0(::Type{S}, ::Type{T})
function throwsize0(S::Type, T::Type)
@_noinline_meta
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
end
function thrownonint(::Type{S}, ::Type{T}, dim)
function thrownonint(S::Type, T::Type, dim)
@_noinline_meta
throw(ArgumentError("""
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
The resulting array would have non-integral first dimension.
"""))
end
function throwaxes1(::Type{S}, ::Type{T}, ax1)
function throwaxes1(S::Type, T::Type, ax1)
@_noinline_meta
throw(ArgumentError("cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first."))
end
Expand All @@ -41,7 +41,7 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
end
readable = array_subpadding(T, S)
writable = array_subpadding(S, T)
new{T, N, S, A}(a, readable, writable)
new{T, N, S, A, sizeof(T) == sizeof(S) && fieldcount(T) + fieldcount(S) == 0}(a, readable, writable)
end
end

Expand Down Expand Up @@ -139,34 +139,32 @@ end

@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)

@inline @propagate_inbounds _getindex_ra(a::ReinterpretArray{T,N,S,A,true}, i1::Int, tailinds::TT) where {T,N,S,A<:AbstractArray{S, N},TT} =
reinterpret(T, a.parent[i1, tailinds...])

@inline @propagate_inbounds function _getindex_ra(a::ReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return reinterpret(T, a.parent[i1, tailinds...])
else
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
_memcpy!(tptr + nbytes_copied, sptr + sidx, nb)
nbytes_copied += nb
sidx = 0
i += 1
end
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
_memcpy!(tptr + nbytes_copied, sptr + sidx, nb)
nbytes_copied += nb
sidx = 0
i += 1
end
return t[]
end
return t[]
end


Expand All @@ -187,47 +185,46 @@ end
_setindex_ra!(a, v, inds[1], tail(inds))
end


@inline @propagate_inbounds _setindex_ra!(a::ReinterpretArray{T,N,S,A,true}, v, i1::Int, tailinds::TT) where {T,N,S,A<:AbstractArray{S, N},TT} =
setindex!(a.parent, reinterpret(S, convert(T, v)::T), i1, tailinds...)

@inline @propagate_inbounds function _setindex_ra!(a::ReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
else
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S) - sidx, sizeof(T))
_memcpy!(sptr + sidx, tptr, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
a.parent[ind_start + i, tailinds...] = s[]
end
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S) - sidx, sizeof(T))
_memcpy!(sptr + sidx, tptr, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
a.parent[ind_start + i, tailinds...] = s[]
end
end
return a
Expand Down