Skip to content

Commit

Permalink
allow arbitrary functions in map
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed May 7, 2021
1 parent 3a26e40 commit 25fc14e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
5 changes: 2 additions & 3 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,5 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} =
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)

## ReshapedArray eltype conversion

map(::Type{T}, R::ReshapedArray) where {T} = ReshapedArray(map(T, parent(R)), R.dims, R.mi)
## ReshapedArrays may forward a mapped function to the parent
map(f, R::ReshapedArray) = ReshapedArray(map(f, parent(R)), R.dims, R.mi)
19 changes: 15 additions & 4 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,9 @@ end
@test @inferred(reduce(vcat, ([10.0], [20.0], Bool[]))) == [10.0, 20.0]
end

@testset "eltype conversion for ReshapedArray" begin
@testset "map on a ReshapedArray (PR #40678)" begin
# Ranges have special map methods for eltype conversion
# reshaped ranges may utilize these
r = 1:4
R = reshape(r, 2, 2)
for T in [Float64, BigInt]
Expand All @@ -1328,11 +1330,20 @@ end
@test parent(S) isa AbstractRange
end

# a more complicated case
# in general for Arrays, map(f, R) and map(x -> f(x), R) should behave identically
# and the result should match map(f, collect(R))
R = reshape(reinterpret(Float64, ComplexF64[i for i = 1:4]), 1, :)
for T in [Int, BigInt]
S = map(T, R)
C = collect(R)
for T in [Int, BigInt], f in [T, x -> T(x)]
S = map(f, R)
@test eltype(S) == T
@test S == R
@test S == map(f, C)
end

R = reshape(1:4, 4, 1)
for f in [CartesianIndex, x -> CartesianIndex(x)]
RR = map(f, R)
@test RR == map(f, collect(R))
end
end

0 comments on commit 25fc14e

Please sign in to comment.