Skip to content

Commit

Permalink
Add functions with dims keyword argument (#518)
Browse files Browse the repository at this point in the history
* Add functions with `dims` keyword argument

Add (unexported) `Compat.accumulate`, `Compat.accumulate!`,
`Compat.all`, `Compat.any`, `Compat.cor`, `Compat.cov`,
`Compat.cumprod`, `Compat.cumprod!`, `Compat.cumsum`, `Compat.cumsum!`,
`Compat.findmax`, `Compat.findmin`, `Compat.mapreduce`,
`Compat.maximum`, `Compat.mean`, `Compat.median`, `Compat.minimum`,
`Compat.prod`, `Compat.reduce`, `Compat.sort`, `Compat.std`,
`Compat.sum`, `Compat.var`, and `Compat.varm` with `dims` keyword
argument.

* add a version check for some tests

* Use correction version bound for conditional tests.

* Remove some unnecessary at-evals

...at the cost of splitting definitions between two if-blocks.
  • Loading branch information
martinholters authored and fredrikekre committed Mar 21, 2018
1 parent 76a456b commit db56c16
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 8 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ Currently, the `@compat` macro supports the following syntaxes:

* `Compat.mv` and `Compat.cp` with `force` keyword argument ([#26069]).

* `Compat.accumulate`, `Compat.accumulate!`, `Compat.all`, `Compat.any`, `Compat.cor`,
`Compat.cov`, `Compat.cumprod`, `Compat.cumprod!`, `Compat.cumsum`, `Compat.cumsum!`,
`Compat.findmax`, `Compat.findmin`, `Compat.mapreduce`, `Compat.maximum`, `Compat.mean`,
`Compat.median`, `Compat.minimum`, `Compat.prod`, `Compat.reduce`, `Compat.sort`,
`Compat.std`, `Compat.sum`, `Compat.var`, and `Compat.varm` with `dims` keyword argument ([#25989],[#26369]).


## Renaming

* `Display` is now `AbstractDisplay` ([#24831]).
Expand Down Expand Up @@ -597,6 +604,7 @@ includes this fix. Find the minimum version from there.
[#25873]: https://github.com/JuliaLang/julia/issues/25873
[#25896]: https://github.com/JuliaLang/julia/issues/25896
[#25959]: https://github.com/JuliaLang/julia/issues/25959
[#25989]: https://github.com/JuliaLang/julia/issues/25989
[#25990]: https://github.com/JuliaLang/julia/issues/25990
[#25998]: https://github.com/JuliaLang/julia/issues/25998
[#26069]: https://github.com/JuliaLang/julia/issues/26069
Expand All @@ -605,5 +613,6 @@ includes this fix. Find the minimum version from there.
[#26156]: https://github.com/JuliaLang/julia/issues/26156
[#26283]: https://github.com/JuliaLang/julia/issues/26283
[#26316]: https://github.com/JuliaLang/julia/issues/26316
[#26369]: https://github.com/JuliaLang/julia/issues/26369
[#26436]: https://github.com/JuliaLang/julia/issues/26436
[#26442]: https://github.com/JuliaLang/julia/issues/26442
83 changes: 75 additions & 8 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ if VERSION < v"0.7.0-DEV.755"
# This is a hack to only add keyword signature that won't work on all julia versions.
# However, since we really only need to support a few (0.5, 0.6 and early 0.7) versions
# this should be good enough.
let Tf = typeof(cov), Tkw = Core.Core.kwftype(Tf)
let Tf = typeof(Base.cov), Tkw = Core.Core.kwftype(Tf)
@eval begin
@inline function _get_corrected(kws)
corrected = true
Expand All @@ -527,14 +527,14 @@ if VERSION < v"0.7.0-DEV.755"
return corrected::Bool
end
if VERSION >= v"0.6"
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = cov(x, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = Base.cov(x, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVector, Y::AbstractVector) =
cov(X, Y, _get_corrected(kws))
Base.cov(X, Y, _get_corrected(kws))
end
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractMatrix, vardim::Int) =
cov(x, vardim, _get_corrected(kws))
Base.cov(x, vardim, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVecOrMat, Y::AbstractVecOrMat,
vardim::Int) = cov(X, Y, vardim, _get_corrected(kws))
vardim::Int) = Base.cov(X, Y, vardim, _get_corrected(kws))
end
end
end
Expand Down Expand Up @@ -948,7 +948,7 @@ end
import Base: diagm
function diagm(kv::Pair...)
T = promote_type(map(x -> eltype(x.second), kv)...)
n = mapreduce(x -> length(x.second) + abs(x.first), max, kv)
n = Base.mapreduce(x -> length(x.second) + abs(x.first), max, kv)
A = zeros(T, n, n)
for p in kv
inds = diagind(A, p.first)
Expand Down Expand Up @@ -1255,7 +1255,7 @@ end

@inline function start(iter::CartesianIndices)
iterfirst, iterlast = first(iter), last(iter)
if any(map(>, iterfirst.I, iterlast.I))
if Base.any(map(>, iterfirst.I, iterlast.I))
return iterlast+1
end
iterfirst
Expand Down Expand Up @@ -1316,7 +1316,7 @@ end
@inline function Base.getindex(iter::LinearIndices{N,R}, I::Vararg{Int, N}) where {N,R}
dims = length.(iter.indices)
#without the inbounds, this is slower than Base._sub2ind(iter.indices, I...)
@inbounds result = reshape(1:prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
@inbounds result = reshape(1:Base.prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
return result
end
elseif VERSION < v"0.7.0-DEV.3395"
Expand Down Expand Up @@ -1709,6 +1709,73 @@ if VERSION < v"0.7.0-DEV.4585"
const lowercasefirst = lcfirst
end

if VERSION < v"0.7.0-DEV.4064"
for f in (:mean, :cumsum, :cumprod, :sum, :prod, :maximum, :minimum, :all, :any, :median)
@eval begin
$f(a::AbstractArray; dims=nothing) =
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
end
end
for f in (:sum, :prod, :maximum, :minimum, :all, :any, :accumulate)
@eval begin
$f(f, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.$f(f, a) : Base.$f(f, a, dims)
end
end
for f in (:findmax, :findmin)
@eval begin
$f(a::AbstractVector; dims=nothing) =
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
function $f(a::AbstractArray; dims=nothing)
vs, inds = dims===nothing ? Base.$f(a) : Base.$f(a, dims)
cis = CartesianIndices(a)
return (vs, map(i -> cis[i], inds))
end
end
end
for f in (:var, :std, :sort)
@eval begin
$f(a::AbstractArray; dims=nothing, kwargs...) =
dims===nothing ? Base.$f(a; kwargs...) : Base.$f(a, dims; kwargs...)
end
end
for f in (:cumsum!, :cumprod!)
@eval $f(out, a; dims=nothing) =
dims===nothing ? Base.$f(out, a) : Base.$f(out, a, dims)
end
end
if VERSION < v"0.7.0-DEV.4064"
varm(A::AbstractArray, m; dims=nothing, kwargs...) =
dims===nothing ? Base.varm(A, m; kwargs...) : Base.varm(A, m, dims; kwargs...)
if VERSION < v"0.7.0-DEV.755"
cov(a::AbstractMatrix; dims=1, corrected=true) = Base.cov(a, dims, corrected)
cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=1, corrected=true) =
Base.cov(a, b, dims, corrected)
else
cov(a::AbstractMatrix; dims=nothing, kwargs...) =
dims===nothing ? Base.cov(a; kwargs...) : Base.cov(a, dims; kwargs...)
cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing, kwargs...) =
dims===nothing ? Base.cov(a, b; kwargs...) : Base.cov(a, b, dims; kwargs...)
end
cor(a::AbstractMatrix; dims=nothing) = dims===nothing ? Base.cor(a) : Base.cor(a, dims)
cor(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing) =
dims===nothing ? Base.cor(a, b) : Base.cor(a, b, dims)
mapreduce(f, op, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.mapreduce(f, op, a) : Base.mapreducedim(f, op, a, dims)
mapreduce(f, op, v0, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.mapreduce(f, op, v0, a) : Base.mapreducedim(f, op, a, dims, v0)
reduce(op, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reduce(op, a) : Base.reducedim(op, a, dims)
reduce(op, v0, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reduce(op, v0, a) : Base.reducedim(op, a, dims, v0)
accumulate!(op, out, a; dims=nothing) =
dims===nothing ? Base.accumulate!(op, out, a) : Base.accumulate!(op, out, a, dims)
end
if VERSION < v"0.7.0-DEV.4534"
reverse(a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reverse(a) : Base.flipdim(a, dims)
end

include("deprecated.jl")

end # module Compat
134 changes: 134 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1535,4 +1535,138 @@ end
@test uppercasefirst("qwerty") == "Qwerty"
@test lowercasefirst("Qwerty") == "qwerty"

# 0.7.0-DEV.4064
# some tests are behind a version check below because Julia gave
# the wrong result between 0.7.0-DEV.3262 and 0.7.0-DEV.4646
# see https://github.com/JuliaLang/julia/issues/26488
Issue26488 = VERSION < v"0.7.0-DEV.3262" || VERSION >= v"0.7.0-DEV.4646"
@test Compat.mean([1 2; 3 4]) == 2.5
@test Compat.mean([1 2; 3 4], dims=1) == [2 3]
@test Compat.mean([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
@test Compat.sum([1 2; 3 4]) == 10
@test Compat.sum([1 2; 3 4], dims=1) == [4 6]
@test Compat.sum([1 2; 3 4], dims=2) == hcat([3; 7])
@test Compat.sum(x -> x+1, [1 2; 3 4]) == 14
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=1) == [6 8]
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=2) == hcat([5; 9])
@test Compat.prod([1 2; 3 4]) == 24
@test Compat.prod([1 2; 3 4], dims=1) == [3 8]
@test Compat.prod([1 2; 3 4], dims=2) == hcat([2; 12])
@test Compat.prod(x -> x+1, [1 2; 3 4]) == 120
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=1) == [8 15]
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=2) == hcat([6; 20])
@test Compat.maximum([1 2; 3 4]) == 4
@test Compat.maximum([1 2; 3 4], dims=1) == [3 4]
@test Compat.maximum([1 2; 3 4], dims=2) == hcat([2; 4])
@test Compat.maximum(x -> x+1, [1 2; 3 4]) == 5
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=1) == [4 5]
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=2) == hcat([3; 5])
@test Compat.minimum([1 2; 3 4]) == 1
@test Compat.minimum([1 2; 3 4], dims=1) == [1 2]
@test Compat.minimum([1 2; 3 4], dims=2) == hcat([1; 3])
@test Compat.minimum(x -> x+1, [1 2; 3 4]) == 2
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=1) == [2 3]
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=2) == hcat([2; 4])
@test Compat.all([true false; true false]) == false
@test Compat.all([true false; true false], dims=1) == [true false]
@test Compat.all([true false; true false], dims=2) == hcat([false; false])
@test Compat.all(isodd, [1 2; 3 4]) == false
@test Compat.all(isodd, [1 2; 3 4], dims=1) == [true false]
@test Compat.all(isodd, [1 2; 3 4], dims=2) == hcat([false; false])
@test Compat.any([true false; true false]) == true
@test Compat.any([true false; true false], dims=1) == [true false]
@test Compat.any([true false; true false], dims=2) == hcat([true; true])
@test Compat.any(isodd, [1 2; 3 4]) == true
@test Compat.any(isodd, [1 2; 3 4], dims=1) == [true false]
@test Compat.any(isodd, [1 2; 3 4], dims=2) == hcat([true; true])
@test Compat.findmax([3, 2, 7, 4]) == (7, 3)
@test Compat.findmax([3, 2, 7, 4], dims=1) == ([7], [3])
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test Compat.findmax([1 2; 3 4]) == (4, CartesianIndex(2, 2))
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test Compat.findmax([1 2; 3 4], dims=2) == (hcat([2; 4]), hcat([CartesianIndex(1, 2); CartesianIndex(2, 2)]))
@test Compat.findmin([3, 2, 7, 4]) == (2, 2)
@test Compat.findmin([3, 2, 7, 4], dims=1) == ([2], [2])
@test Compat.findmin([1 2; 3 4]) == (1, CartesianIndex(1, 1))
@test Compat.findmin([1 2; 3 4], dims=1) == ([1 2], [CartesianIndex(1, 1) CartesianIndex(1, 2)])
@test Compat.findmin([1 2; 3 4], dims=2) == (hcat([1; 3]), hcat([CartesianIndex(1, 1); CartesianIndex(2, 1)]))
@test Compat.varm([1 2; 3 4], -1) == 18
@test Compat.varm([1 2; 3 4], [-1 -2], dims=1) == [20 52]
@test Compat.varm([1 2; 3 4], [-1, -2], dims=2) == hcat([13, 61])
@test Compat.var([1 2; 3 4]) == 5/3
@test Compat.var([1 2; 3 4], dims=1) == [2 2]
@test Compat.var([1 2; 3 4], dims=2) == hcat([0.5, 0.5])
@test Compat.var([1 2; 3 4], corrected=false) == 1.25
@test Compat.var([1 2; 3 4], corrected=false, dims=1) == [1 1]
@test Compat.var([1 2; 3 4], corrected=false, dims=2) == hcat([0.25, 0.25])
@test Compat.std([1 2; 3 4]) == sqrt(5/3)
@test Compat.std([1 2; 3 4], dims=1) == [sqrt(2) sqrt(2)]
@test Compat.std([1 2; 3 4], dims=2) == hcat([sqrt(0.5), sqrt(0.5)])
@test Compat.std([1 2; 3 4], corrected=false) == sqrt(1.25)
@test Compat.std([1 2; 3 4], corrected=false, dims=1) == [sqrt(1) sqrt(1)]
@test Compat.std([1 2; 3 4], corrected=false, dims=2) == hcat([sqrt(0.25), sqrt(0.25)])
@test Compat.cov([1 2; 3 4]) == [2 2; 2 2]
@test Compat.cov([1 2; 3 4], dims=1) == [2 2; 2 2]
@test Compat.cov([1 2; 3 4], dims=2) == [0.5 0.5; 0.5 0.5]
@test Compat.cov([1 2; 3 4], [4; 5]) == hcat([1; 1])
@test Compat.cov([1 2; 3 4], [4; 5], dims=1) == hcat([1; 1])
@test Compat.cov([1 2; 3 4], [4; 5], dims=2) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=1) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=2) == hcat([0.25; 0.25])
@test Compat.cor([1 2; 3 4]) [1 1; 1 1]
@test Compat.cor([1 2; 3 4], dims=1) [1 1; 1 1]
@test Compat.cor([1 2; 3 4], dims=2) [1 1; 1 1]
@test Compat.cor([1 2; 3 4], [4; 5]) [1; 1]
@test Compat.cor([1 2; 3 4], [4; 5], dims=1) [1; 1]
@test Compat.cor([1 2; 3 4], [4; 5], dims=2) [1; 1]
@test Compat.median([1 2; 3 4]) == 2.5
@test Compat.median([1 2; 3 4], dims=1) == [2 3]
@test Compat.median([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
@test Compat.mapreduce(string, *, [1 2; 3 4]) == "1324"
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=1) == ["13" "24"]
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=2) == hcat(["12", "34"])
@test Compat.mapreduce(string, *, "z", [1 2; 3 4]) == "z1324"
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=1) == ["z13" "z24"]
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=2) == hcat(["z12", "z34"])
@test Compat.reduce(*, [1 2; 3 4]) == 24
@test Compat.reduce(*, [1 2; 3 4], dims=1) == [3 8]
@test Compat.reduce(*, [1 2; 3 4], dims=2) == hcat([2, 12])
@test Compat.reduce(*, 10, [1 2; 3 4]) == 240
@test Compat.reduce(*, 10, [1 2; 3 4], dims=1) == [30 80]
@test Compat.reduce(*, 10, [1 2; 3 4], dims=2) == hcat([20, 120])
@test Compat.sort([1, 2, 3, 4]) == [1, 2, 3, 4]
@test Compat.sort([1 2; 3 4], dims=1) == [1 2; 3 4]
@test Compat.sort([1 2; 3 4], dims=2) == [1 2; 3 4]
@test Compat.sort([1, 2, 3, 4], rev=true) == [4, 3, 2, 1]
@test Compat.sort([1 2; 3 4], rev=true, dims=1) == [3 4; 1 2]
@test Compat.sort([1 2; 3 4], rev=true, dims=2) == [2 1; 4 3]
@test Compat.accumulate(*, [1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.accumulate(*, [1 2; 3 4], dims=2) == [1 2; 3 12]
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
let b = zeros(2,2)
Compat.accumulate!(*, b, [1 2; 3 4], dims=1)
@test b == [1 2; 3 8]
Compat.accumulate!(*, b, [1 2; 3 4], dims=2)
@test b == [1 2; 3 12]
Compat.cumsum!(b, [1 2; 3 4], dims=1)
@test b == [1 2; 4 6]
Compat.cumsum!(b, [1 2; 3 4], dims=2)
@test b == [1 3; 3 7]
Compat.cumprod!(b, [1 2; 3 4], dims=1)
@test b == [1 2; 3 8]
Compat.cumprod!(b, [1 2; 3 4], dims=2)
@test b == [1 2; 3 12]
end
@test Compat.reverse([1, 2, 3, 4]) == [4, 3, 2, 1]
@test Compat.reverse([1 2; 3 4], dims=1) == [3 4; 1 2]
@test Compat.reverse([1 2; 3 4], dims=2) == [2 1; 4 3]

nothing

0 comments on commit db56c16

Please sign in to comment.