Skip to content

Commit

Permalink
Simplify trainable, functor and Parallel (#1862)
Browse files Browse the repository at this point in the history
* simple functor Chain

* simplify Maxout

* fix show as a result

* trainable always a NamedTuple

* Parallel: delete trainable, call combiner once

* fixup

* fix tests for Flux.modules
  • Loading branch information
mcabbott authored Feb 5, 2022
1 parent 841afe7 commit 9b21e2c
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 94 deletions.
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,


# v0.13 deprecations
@deprecate Maxout(layers::Tuple) Maxout(layers...)
67 changes: 34 additions & 33 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,30 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
"""
struct Chain{T}
struct Chain{T<:Union{Tuple, NamedTuple}}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
end

Chain(xs...) = Chain(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return Chain(())
Chain(values(kw))
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
@functor Chain

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(Tuple(c.layers), x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
Expand Down Expand Up @@ -246,29 +247,23 @@ julia> Flux.outputsize(m3, (5, 11))
(7, 11)
```
"""
struct Maxout{FS<:Tuple}
over::FS
Maxout(layers...) = new{typeof(layers)}(layers)
end

function Maxout(f::Function, n_alts::Integer)
over = Tuple(f() for _ in 1:n_alts)
return Maxout(over...)
struct Maxout{T<:Tuple}
layers::T
end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
# even with Zygote. See #698 and #1794
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
end

trainable(mo::Maxout) = mo.over

function Base.show(io::IO, mo::Maxout)
print(io, "Maxout(")
_show_layers(io, mo.over)
_show_layers(io, mo.layers)
print(io, ")")
end

Expand Down Expand Up @@ -415,8 +410,8 @@ end
Create a `Parallel` layer that passes an input array to each path in
`layers`, before reducing the output with `connection`.
Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
Expand Down Expand Up @@ -451,7 +446,7 @@ julia> model2[:β] == model2[2]
true
```
"""
struct Parallel{F, T}
struct Parallel{F, T<:Union{Tuple, NamedTuple}}
connection::F
layers::T
end
Expand All @@ -461,25 +456,31 @@ function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
elseif isempty(layers)
Parallel(connection, ())
end
isempty(layers) && return Parallel(connection, ())
Parallel(connection, layers)
end

@functor Parallel

(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
function (m::Parallel)(xs...)
nl = length(m.layers)
nx = length(xs)
if nl != nx
throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
m.connection(map(|>, xs, Tuple(m.layers))...)
end

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))

trainable(m::Parallel) = (m.connection, m.layers...)

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
_show_layers(io, m.layers)
Expand Down
12 changes: 5 additions & 7 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ function Dropout(p; dims=:, rng = rng_from_array())
end

@functor Dropout

trainable(a::Dropout) = ()
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
Expand Down Expand Up @@ -122,8 +121,7 @@ AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout

trainable(a::AlphaDropout) = ()
trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
Expand Down Expand Up @@ -301,7 +299,7 @@ function BatchNorm(chs::Int, λ=identity;
end

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
Expand Down Expand Up @@ -377,7 +375,7 @@ function InstanceNorm(chs::Int, λ=identity;
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
Expand Down Expand Up @@ -439,7 +437,7 @@ mutable struct GroupNorm{F,V,N,W}
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
Expand Down
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function (m::Recur)(x)
end

@functor Recur
trainable(a::Recur) = (a.cell,)
trainable(a::Recur) = (; cell = a.cell)

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

Expand Down
7 changes: 6 additions & 1 deletion src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
children = trainable(obj)
children = _show_children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
Expand Down Expand Up @@ -48,6 +48,11 @@ _show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
_show_leaflike(::Diagonal) = true # appears inside LayerNorm

_show_children(x) = trainable(x) # except for layers which hide their Tuple:
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
Expand Down
7 changes: 6 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,15 +775,20 @@ Chain(
# plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.
julia> Flux.modules(m2)
5-element Vector{Any}:
7-element Vector{Any}:
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable
(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable
(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64) # 50_240 parameters
BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable
Dense(64, 10) # 650 parameters
julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
julia> L2(m2) isa Float32
true
```
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
Expand Down
28 changes: 28 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import Flux: activations
@test m[:first] == m[1]
@test m[1:2] == m

@test m == m
@test m == fmap(identity, m) # does not forget names

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
end

Expand Down Expand Up @@ -202,14 +205,39 @@ import Flux: activations
inputs = randn(10), randn(5), randn(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
@test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
@test Parallel(+, sin, cos)(pi/2) 1
end

@testset "named access" begin
m = Parallel(hcat, one = Dense(10, 10), two = identity)
@test m[1] == m[:one]
@test m[1:2] == m

@test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names
@test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)

@test m == fmap(identity, m) # does not forget names

@test Parallel(vcat, x = log)(1) == [0]
@test Parallel(vcat, log)(1) == [0]
end

@testset "trivial cases" begin
@test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
@test Parallel(hcat)(1) == hcat()
@test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
end

@testset "connection is called once" begin
CNT = Ref(0)
f_cnt = (x...) -> (CNT[]+=1; +(x...))
Parallel(f_cnt, sin, cos, tan)(1)
@test CNT[] == 1
Parallel(f_cnt, sin, cos, tan)(1,2,3)
@test CNT[] == 2
Parallel(f_cnt, sin)(1)
@test CNT[] == 3
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
Expand Down
85 changes: 44 additions & 41 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,58 @@ using CUDA

Random.seed!(0)

@testset "Utils" begin
include("utils.jl")
end
@testset verbose=true "Flux.jl" begin

@testset "Onehot" begin
include("onehot.jl")
end
@testset "Utils" begin
include("utils.jl")
end

@testset "Optimise" begin
include("optimise.jl")
end
@testset "Onehot" begin
include("onehot.jl")
end

@testset "Data" begin
include("data.jl")
end
@testset "Optimise" begin
include("optimise.jl")
end

@testset "Losses" begin
include("losses.jl")
include("ctc.jl")
CUDA.functional() && include("ctc-gpu.jl")
end
@testset "Data" begin
include("data.jl")
end

@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
include("layers/show.jl")
end
@testset "Losses" begin
include("losses.jl")
include("ctc.jl")
CUDA.functional() && include("ctc-gpu.jl")
end

@testset "outputsize" begin
using Flux: outputsize
include("outputsize.jl")
end
@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
include("layers/show.jl")
end

@testset "CUDA" begin
if CUDA.functional()
include("cuda/runtests.jl")
else
@warn "CUDA unavailable, not testing GPU support"
@testset "outputsize" begin
using Flux: outputsize
include("outputsize.jl")
end

@testset "CUDA" begin
if CUDA.functional()
include("cuda/runtests.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end
end

@static if VERSION == v"1.6"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
@static if VERSION == v"1.6"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
end
end
Loading

0 comments on commit 9b21e2c

Please sign in to comment.