Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions with dims keyword argument #518

Merged
merged 4 commits into from
Mar 21, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ Currently, the `@compat` macro supports the following syntaxes:

* `Compat.mv` and `Compat.cp` with `force` keyword argument ([#26069]).

* `Compat.accumulate`, `Compat.accumulate!`, `Compat.all`, `Compat.any`, `Compat.cor`,
`Compat.cov`, `Compat.cumprod`, `Compat.cumprod!`, `Compat.cumsum`, `Compat.cumsum!`,
`Compat.findmax`, `Compat.findmin`, `Compat.mapreduce`, `Compat.maximum`, `Compat.mean`,
`Compat.median`, `Compat.minimum`, `Compat.prod`, `Compat.reduce`, `Compat.sort`,
`Compat.std`, `Compat.sum`, `Compat.var`, and `Compat.varm` with `dims` keyword argument ([#25989],[#26369]).


## Renaming

* `Display` is now `AbstractDisplay` ([#24831]).
Expand Down Expand Up @@ -592,12 +599,14 @@ includes this fix. Find the minimum version from there.
[#25873]: https://github.com/JuliaLang/julia/issues/25873
[#25896]: https://github.com/JuliaLang/julia/issues/25896
[#25959]: https://github.com/JuliaLang/julia/issues/25959
[#25989]: https://github.com/JuliaLang/julia/issues/25989
[#25990]: https://github.com/JuliaLang/julia/issues/25990
[#25998]: https://github.com/JuliaLang/julia/issues/25998
[#26069]: https://github.com/JuliaLang/julia/issues/26069
[#26089]: https://github.com/JuliaLang/julia/issues/26089
[#26149]: https://github.com/JuliaLang/julia/issues/26149
[#26156]: https://github.com/JuliaLang/julia/issues/26156
[#26316]: https://github.com/JuliaLang/julia/issues/26316
[#26369]: https://github.com/JuliaLang/julia/issues/26369
[#26436]: https://github.com/JuliaLang/julia/issues/26436
[#26442]: https://github.com/JuliaLang/julia/issues/26442
82 changes: 74 additions & 8 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ if VERSION < v"0.7.0-DEV.755"
# This is a hack to only add keyword signature that won't work on all julia versions.
# However, since we really only need to support a few (0.5, 0.6 and early 0.7) versions
# this should be good enough.
let Tf = typeof(cov), Tkw = Core.Core.kwftype(Tf)
let Tf = typeof(Base.cov), Tkw = Core.Core.kwftype(Tf)
@eval begin
@inline function _get_corrected(kws)
corrected = true
Expand All @@ -527,14 +527,14 @@ if VERSION < v"0.7.0-DEV.755"
return corrected::Bool
end
if VERSION >= v"0.6"
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = cov(x, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = Base.cov(x, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVector, Y::AbstractVector) =
cov(X, Y, _get_corrected(kws))
Base.cov(X, Y, _get_corrected(kws))
end
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractMatrix, vardim::Int) =
cov(x, vardim, _get_corrected(kws))
Base.cov(x, vardim, _get_corrected(kws))
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVecOrMat, Y::AbstractVecOrMat,
vardim::Int) = cov(X, Y, vardim, _get_corrected(kws))
vardim::Int) = Base.cov(X, Y, vardim, _get_corrected(kws))
end
end
end
Expand Down Expand Up @@ -929,7 +929,7 @@ end
import Base: diagm
function diagm(kv::Pair...)
T = promote_type(map(x -> eltype(x.second), kv)...)
n = mapreduce(x -> length(x.second) + abs(x.first), max, kv)
n = Base.mapreduce(x -> length(x.second) + abs(x.first), max, kv)
A = zeros(T, n, n)
for p in kv
inds = diagind(A, p.first)
Expand Down Expand Up @@ -1236,7 +1236,7 @@ end

@inline function start(iter::CartesianIndices)
iterfirst, iterlast = first(iter), last(iter)
if any(map(>, iterfirst.I, iterlast.I))
if Base.any(map(>, iterfirst.I, iterlast.I))
return iterlast+1
end
iterfirst
Expand Down Expand Up @@ -1297,7 +1297,7 @@ end
@inline function Base.getindex(iter::LinearIndices{N,R}, I::Vararg{Int, N}) where {N,R}
dims = length.(iter.indices)
#without the inbounds, this is slower than Base._sub2ind(iter.indices, I...)
@inbounds result = reshape(1:prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
@inbounds result = reshape(1:Base.prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
return result
end
elseif VERSION < v"0.7.0-DEV.3395"
Expand Down Expand Up @@ -1690,6 +1690,72 @@ if VERSION < v"0.7.0-DEV.4585"
const lowercasefirst = lcfirst
end

if VERSION < v"0.7.0-DEV.4064"
for f in (:mean, :cumsum, :cumprod, :sum, :prod, :maximum, :minimum, :all, :any, :median)
@eval begin
$f(a::AbstractArray; dims=nothing) =
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
end
end
for f in (:sum, :prod, :maximum, :minimum, :all, :any, :accumulate)
@eval begin
$f(f, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.$f(f, a) : Base.$f(f, a, dims)
end
end
for f in (:findmax, :findmin)
@eval begin
$f(a::AbstractVector; dims=nothing) =
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
function $f(a::AbstractArray; dims=nothing)
vs, inds = dims===nothing ? Base.$f(a) : Base.$f(a, dims)
cis = CartesianIndices(a)
return (vs, map(i -> cis[i], inds))
end
end
end
@eval varm(A::AbstractArray, m; dims=nothing, kwargs...) =
dims===nothing ? Base.varm(A, m; kwargs...) : Base.varm(A, m, dims; kwargs...)
for f in (:var, :std, :sort)
@eval begin
$f(a::AbstractArray; dims=nothing, kwargs...) =
dims===nothing ? Base.$f(a; kwargs...) : Base.$f(a, dims; kwargs...)
end
end
if VERSION < v"0.7.0-DEV.755"
@eval cov(a::AbstractMatrix; dims=1, corrected=true) = Base.cov(a, dims, corrected)
@eval cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=1, corrected=true) =
Base.cov(a, b, dims, corrected)
else
@eval cov(a::AbstractMatrix; dims=nothing, kwargs...) =
dims===nothing ? Base.cov(a; kwargs...) : Base.cov(a, dims; kwargs...)
@eval cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing, kwargs...) =
dims===nothing ? Base.cov(a, b; kwargs...) : Base.cov(a, b, dims; kwargs...)
end
@eval cor(a::AbstractMatrix; dims=nothing) =
dims===nothing ? Base.cor(a) : Base.cor(a, dims)
@eval cor(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing) =
dims===nothing ? Base.cor(a, b) : Base.cor(a, b, dims)
@eval mapreduce(f, op, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.mapreduce(f, op, a) : Base.mapreducedim(f, op, a, dims)
@eval mapreduce(f, op, v0, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.mapreduce(f, op, v0, a) : Base.mapreducedim(f, op, a, dims, v0)
@eval reduce(op, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reduce(op, a) : Base.reducedim(op, a, dims)
@eval reduce(op, v0, a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reduce(op, v0, a) : Base.reducedim(op, a, dims, v0)
@eval accumulate!(op, out, a; dims=nothing) =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess these @eval could be removed? Only the ones in the loops over function symbols is needed AFAICT.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I'll update.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Um, no, that gives

ERROR: LoadError: error compiling anonymous: error in method definition: function Base.accumulate! must be explicitly imported to be extended

There is no other use accumulate! though, so why doesn't it just create a new binding to a new function? BUT if I move the definition to its own if block, it does work without the @eval. Strange...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I tried the same, lets just leave it as is I guess.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, seems to be some Julia issue on 0.6, but is fixed on master:

julia> module M1
           if true
               identity(x) = Base.identity(x)
           end
       end
M1

julia> module M2
           if true
               for f in (:a, :b); end
               identity(x) = Base.identity(x)
           end
       end
ERROR: error compiling anonymous: error in method definition: function Base.identity must be explicitly imported to be extended

dims===nothing ? Base.accumulate!(op, out, a) : Base.accumulate!(op, out, a, dims)
for f in (:cumsum!, :cumprod!)
@eval $f(out, a; dims=nothing) =
dims===nothing ? Base.$f(out, a) : Base.$f(out, a, dims)
end
end
if VERSION < v"0.7.0-DEV.4534"
@eval reverse(a::AbstractArray; dims=nothing) =
dims===nothing ? Base.reverse(a) : Base.flipdim(a, dims)
end

include("deprecated.jl")

end # module Compat
134 changes: 134 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1526,4 +1526,138 @@ end
@test uppercasefirst("qwerty") == "Qwerty"
@test lowercasefirst("Qwerty") == "qwerty"

# 0.7.0-DEV.4064
# some tests are behind a version check below because Julia gave
# the wrong result between 0.7.0-DEV.4064 and 0.7.0-DEV.4646
# see https://github.com/JuliaLang/julia/issues/26488
Issue26488 = VERSION < v"0.7.0-DEV.4064" || VERSION >= v"0.7.0-DEV.4646"
@test Compat.mean([1 2; 3 4]) == 2.5
@test Compat.mean([1 2; 3 4], dims=1) == [2 3]
@test Compat.mean([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
@test Compat.sum([1 2; 3 4]) == 10
@test Compat.sum([1 2; 3 4], dims=1) == [4 6]
@test Compat.sum([1 2; 3 4], dims=2) == hcat([3; 7])
@test Compat.sum(x -> x+1, [1 2; 3 4]) == 14
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=1) == [6 8]
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=2) == hcat([5; 9])
@test Compat.prod([1 2; 3 4]) == 24
@test Compat.prod([1 2; 3 4], dims=1) == [3 8]
@test Compat.prod([1 2; 3 4], dims=2) == hcat([2; 12])
@test Compat.prod(x -> x+1, [1 2; 3 4]) == 120
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=1) == [8 15]
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=2) == hcat([6; 20])
@test Compat.maximum([1 2; 3 4]) == 4
@test Compat.maximum([1 2; 3 4], dims=1) == [3 4]
@test Compat.maximum([1 2; 3 4], dims=2) == hcat([2; 4])
@test Compat.maximum(x -> x+1, [1 2; 3 4]) == 5
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=1) == [4 5]
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=2) == hcat([3; 5])
@test Compat.minimum([1 2; 3 4]) == 1
@test Compat.minimum([1 2; 3 4], dims=1) == [1 2]
@test Compat.minimum([1 2; 3 4], dims=2) == hcat([1; 3])
@test Compat.minimum(x -> x+1, [1 2; 3 4]) == 2
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=1) == [2 3]
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=2) == hcat([2; 4])
@test Compat.all([true false; true false]) == false
@test Compat.all([true false; true false], dims=1) == [true false]
@test Compat.all([true false; true false], dims=2) == hcat([false; false])
@test Compat.all(isodd, [1 2; 3 4]) == false
@test Compat.all(isodd, [1 2; 3 4], dims=1) == [true false]
@test Compat.all(isodd, [1 2; 3 4], dims=2) == hcat([false; false])
@test Compat.any([true false; true false]) == true
@test Compat.any([true false; true false], dims=1) == [true false]
@test Compat.any([true false; true false], dims=2) == hcat([true; true])
@test Compat.any(isodd, [1 2; 3 4]) == true
@test Compat.any(isodd, [1 2; 3 4], dims=1) == [true false]
@test Compat.any(isodd, [1 2; 3 4], dims=2) == hcat([true; true])
@test Compat.findmax([3, 2, 7, 4]) == (7, 3)
@test Compat.findmax([3, 2, 7, 4], dims=1) == ([7], [3])
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test Compat.findmax([1 2; 3 4]) == (4, CartesianIndex(2, 2))
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test Compat.findmax([1 2; 3 4], dims=2) == (hcat([2; 4]), hcat([CartesianIndex(1, 2); CartesianIndex(2, 2)]))
@test Compat.findmin([3, 2, 7, 4]) == (2, 2)
@test Compat.findmin([3, 2, 7, 4], dims=1) == ([2], [2])
@test Compat.findmin([1 2; 3 4]) == (1, CartesianIndex(1, 1))
@test Compat.findmin([1 2; 3 4], dims=1) == ([1 2], [CartesianIndex(1, 1) CartesianIndex(1, 2)])
@test Compat.findmin([1 2; 3 4], dims=2) == (hcat([1; 3]), hcat([CartesianIndex(1, 1); CartesianIndex(2, 1)]))
@test Compat.varm([1 2; 3 4], -1) == 18
@test Compat.varm([1 2; 3 4], [-1 -2], dims=1) == [20 52]
@test Compat.varm([1 2; 3 4], [-1, -2], dims=2) == hcat([13, 61])
@test Compat.var([1 2; 3 4]) == 5/3
@test Compat.var([1 2; 3 4], dims=1) == [2 2]
@test Compat.var([1 2; 3 4], dims=2) == hcat([0.5, 0.5])
@test Compat.var([1 2; 3 4], corrected=false) == 1.25
@test Compat.var([1 2; 3 4], corrected=false, dims=1) == [1 1]
@test Compat.var([1 2; 3 4], corrected=false, dims=2) == hcat([0.25, 0.25])
@test Compat.std([1 2; 3 4]) == sqrt(5/3)
@test Compat.std([1 2; 3 4], dims=1) == [sqrt(2) sqrt(2)]
@test Compat.std([1 2; 3 4], dims=2) == hcat([sqrt(0.5), sqrt(0.5)])
@test Compat.std([1 2; 3 4], corrected=false) == sqrt(1.25)
@test Compat.std([1 2; 3 4], corrected=false, dims=1) == [sqrt(1) sqrt(1)]
@test Compat.std([1 2; 3 4], corrected=false, dims=2) == hcat([sqrt(0.25), sqrt(0.25)])
@test Compat.cov([1 2; 3 4]) == [2 2; 2 2]
@test Compat.cov([1 2; 3 4], dims=1) == [2 2; 2 2]
@test Compat.cov([1 2; 3 4], dims=2) == [0.5 0.5; 0.5 0.5]
@test Compat.cov([1 2; 3 4], [4; 5]) == hcat([1; 1])
@test Compat.cov([1 2; 3 4], [4; 5], dims=1) == hcat([1; 1])
@test Compat.cov([1 2; 3 4], [4; 5], dims=2) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=1) == hcat([0.5; 0.5])
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=2) == hcat([0.25; 0.25])
@test Compat.cor([1 2; 3 4]) ≈ [1 1; 1 1]
@test Compat.cor([1 2; 3 4], dims=1) ≈ [1 1; 1 1]
@test Compat.cor([1 2; 3 4], dims=2) ≈ [1 1; 1 1]
@test Compat.cor([1 2; 3 4], [4; 5]) ≈ [1; 1]
@test Compat.cor([1 2; 3 4], [4; 5], dims=1) ≈ [1; 1]
@test Compat.cor([1 2; 3 4], [4; 5], dims=2) ≈ [1; 1]
@test Compat.median([1 2; 3 4]) == 2.5
@test Compat.median([1 2; 3 4], dims=1) == [2 3]
@test Compat.median([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
@test Compat.mapreduce(string, *, [1 2; 3 4]) == "1324"
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=1) == ["13" "24"]
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=2) == hcat(["12", "34"])
@test Compat.mapreduce(string, *, "z", [1 2; 3 4]) == "z1324"
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=1) == ["z13" "z24"]
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=2) == hcat(["z12", "z34"])
@test Compat.reduce(*, [1 2; 3 4]) == 24
@test Compat.reduce(*, [1 2; 3 4], dims=1) == [3 8]
@test Compat.reduce(*, [1 2; 3 4], dims=2) == hcat([2, 12])
@test Compat.reduce(*, 10, [1 2; 3 4]) == 240
@test Compat.reduce(*, 10, [1 2; 3 4], dims=1) == [30 80]
@test Compat.reduce(*, 10, [1 2; 3 4], dims=2) == hcat([20, 120])
@test Compat.sort([1, 2, 3, 4]) == [1, 2, 3, 4]
@test Compat.sort([1 2; 3 4], dims=1) == [1 2; 3 4]
@test Compat.sort([1 2; 3 4], dims=2) == [1 2; 3 4]
@test Compat.sort([1, 2, 3, 4], rev=true) == [4, 3, 2, 1]
@test Compat.sort([1 2; 3 4], rev=true, dims=1) == [3 4; 1 2]
@test Compat.sort([1 2; 3 4], rev=true, dims=2) == [2 1; 4 3]
@test Compat.accumulate(*, [1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.accumulate(*, [1 2; 3 4], dims=2) == [1 2; 3 12]
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
let b = zeros(2,2)
Compat.accumulate!(*, b, [1 2; 3 4], dims=1)
@test b == [1 2; 3 8]
Compat.accumulate!(*, b, [1 2; 3 4], dims=2)
@test b == [1 2; 3 12]
Compat.cumsum!(b, [1 2; 3 4], dims=1)
@test b == [1 2; 4 6]
Compat.cumsum!(b, [1 2; 3 4], dims=2)
@test b == [1 3; 3 7]
Compat.cumprod!(b, [1 2; 3 4], dims=1)
@test b == [1 2; 3 8]
Compat.cumprod!(b, [1 2; 3 4], dims=2)
@test b == [1 2; 3 12]
end
@test Compat.reverse([1, 2, 3, 4]) == [4, 3, 2, 1]
@test Compat.reverse([1 2; 3 4], dims=1) == [3 4; 1 2]
@test Compat.reverse([1 2; 3 4], dims=2) == [2 1; 4 3]

nothing