Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify trainable, functor and Parallel #1862

Merged
merged 7 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,


# v0.13 deprecations
@deprecate Maxout(layers::Tuple) Maxout(layers...)
67 changes: 34 additions & 33 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,30 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
"""
struct Chain{T}
struct Chain{T<:Union{Tuple, NamedTuple}}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
end

Chain(xs...) = Chain(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return Chain(())
Chain(values(kw))
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
@functor Chain

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(Tuple(c.layers), x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
Expand Down Expand Up @@ -245,29 +246,23 @@ julia> Flux.outputsize(m3, (5, 11))
(7, 11)
```
"""
struct Maxout{FS<:Tuple}
over::FS
Maxout(layers...) = new{typeof(layers)}(layers)
end

function Maxout(f::Function, n_alts::Integer)
over = Tuple(f() for _ in 1:n_alts)
return Maxout(over...)
struct Maxout{T<:Tuple}
layers::T
end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
# even with Zygote. See #698 and #1794
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
end

trainable(mo::Maxout) = mo.over

function Base.show(io::IO, mo::Maxout)
print(io, "Maxout(")
_show_layers(io, mo.over)
_show_layers(io, mo.layers)
print(io, ")")
end

Expand Down Expand Up @@ -414,8 +409,8 @@ end
Create a `Parallel` layer that passes an input array to each path in
`layers`, before reducing the output with `connection`.

Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.

Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
Expand Down Expand Up @@ -450,7 +445,7 @@ julia> model2[:β] == model2[2]
true
```
"""
struct Parallel{F, T}
struct Parallel{F, T<:Union{Tuple, NamedTuple}}
connection::F
layers::T
end
Expand All @@ -460,25 +455,31 @@ function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
elseif isempty(layers)
Parallel(connection, ())
end
isempty(layers) && return Parallel(connection, ())
Parallel(connection, layers)
end

@functor Parallel

(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
function (m::Parallel)(xs...)
nl = length(m.layers)
nx = length(xs)
if nl != nx
throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
m.connection(map(|>, xs, Tuple(m.layers))...)
end

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))

trainable(m::Parallel) = (m.connection, m.layers...)

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
_show_layers(io, m.layers)
Expand Down
12 changes: 5 additions & 7 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ function Dropout(p; dims=:, rng = rng_from_array())
end

@functor Dropout
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved

trainable(a::Dropout) = ()
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
Expand Down Expand Up @@ -122,8 +121,7 @@ AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout

trainable(a::AlphaDropout) = ()
trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
Expand Down Expand Up @@ -288,7 +286,7 @@ function BatchNorm(chs::Int, λ=identity;
end

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
Expand Down Expand Up @@ -364,7 +362,7 @@ function InstanceNorm(chs::Int, λ=identity;
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
Expand Down Expand Up @@ -426,7 +424,7 @@ mutable struct GroupNorm{F,V,N,W}
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
Expand Down
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function (m::Recur)(x)
end

@functor Recur
trainable(a::Recur) = (a.cell,)
trainable(a::Recur) = (; cell = a.cell)

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

Expand Down
7 changes: 6 additions & 1 deletion src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
children = trainable(obj)
children = _show_children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
Expand Down Expand Up @@ -48,6 +48,11 @@ _show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
_show_leaflike(::Diagonal) = true # appears inside LayerNorm

_show_children(x) = trainable(x) # except for layers which hide their Tuple:
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
Expand Down
7 changes: 6 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,15 +775,20 @@ Chain(
# plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
5-element Vector{Any}:
7-element Vector{Any}:
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable
(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable
(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64) # 50_240 parameters
BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable
Dense(64, 10) # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true
```
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
Expand Down
28 changes: 28 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import Flux: activations
@test m[:first] == m[1]
@test m[1:2] == m

@test m == m
@test m == fmap(identity, m) # does not forget names

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
end

Expand Down Expand Up @@ -202,14 +205,39 @@ import Flux: activations
inputs = randn(10), randn(5), randn(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
@test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
@test Parallel(+, sin, cos)(pi/2) ≈ 1
end

@testset "named access" begin
m = Parallel(hcat, one = Dense(10, 10), two = identity)
@test m[1] == m[:one]
@test m[1:2] == m

@test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names
@test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)

@test m == fmap(identity, m) # does not forget names

@test Parallel(vcat, x = log)(1) == [0]
@test Parallel(vcat, log)(1) == [0]
end

@testset "trivial cases" begin
@test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
@test Parallel(hcat)(1) == hcat()
@test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
end

@testset "connection is called once" begin
CNT = Ref(0)
f_cnt = (x...) -> (CNT[]+=1; +(x...))
Parallel(f_cnt, sin, cos, tan)(1)
@test CNT[] == 1
Parallel(f_cnt, sin, cos, tan)(1,2,3)
@test CNT[] == 2
Parallel(f_cnt, sin)(1)
@test CNT[] == 3
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
Expand Down
85 changes: 44 additions & 41 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,58 @@ using CUDA

Random.seed!(0)

@testset "Utils" begin
include("utils.jl")
end
@testset verbose=true "Flux.jl" begin
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR also adds an overall testset, so that all tests are run even if one fails near the start.


@testset "Onehot" begin
include("onehot.jl")
end
@testset "Utils" begin
include("utils.jl")
end

@testset "Optimise" begin
include("optimise.jl")
end
@testset "Onehot" begin
include("onehot.jl")
end

@testset "Data" begin
include("data.jl")
end
@testset "Optimise" begin
include("optimise.jl")
end

@testset "Losses" begin
include("losses.jl")
include("ctc.jl")
CUDA.functional() && include("ctc-gpu.jl")
end
@testset "Data" begin
include("data.jl")
end

@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
include("layers/show.jl")
end
@testset "Losses" begin
include("losses.jl")
include("ctc.jl")
CUDA.functional() && include("ctc-gpu.jl")
end

@testset "outputsize" begin
using Flux: outputsize
include("outputsize.jl")
end
@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
include("layers/show.jl")
end

@testset "CUDA" begin
if CUDA.functional()
include("cuda/runtests.jl")
else
@warn "CUDA unavailable, not testing GPU support"
@testset "outputsize" begin
using Flux: outputsize
include("outputsize.jl")
end

@testset "CUDA" begin
if CUDA.functional()
include("cuda/runtests.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end
end

@static if VERSION == v"1.6"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
@static if VERSION == v"1.6"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
end
end
Loading