diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 254f06db0c..3c615ae06d 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
 
+A chain may be called with multiple arguments, which is equivalent to calling it
+with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
+to mean the same as several arguments:
+
+```jldoctest
+julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple
+(1, 2, 3)
+nothing
+
+julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2
+x = (4, 5)
+25.25
+```
+
 For large models, there is a special type-unstable path which can reduce compilation
 times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
 This feature is somewhat experimental, beware!
@@ -46,9 +60,10 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys, Base.firstindex
 
-@layer :expand Chain # the + opts-in to container-style pretty-printing
+@layer :expand Chain # the option :expand opts in to container-style pretty-printing
 
 (c::Chain)(x) = _applychain(c.layers, x)
+(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
 
 @generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
   symbols = vcat(:x, [gensym() for _ in 1:N])
@@ -68,6 +83,7 @@ end
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
     Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
+
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
   _show_layers(io, c.layers)
@@ -475,8 +491,11 @@ end
 Create a layer which passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
-Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
-If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Obeys rules similar to broadcasting:
+* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
+* With multiple inputs and multiple layers, one input is passed to each layer,
+  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -486,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`.
 
 # Examples
 
+```jldoctest
+julia> p = Parallel(+, abs2, sqrt);
+
+julia> p(3, 4) # == 3^2 + √4, two functions, two inputs
+11.0
+
+julia> p((3, 4)) # a tuple is always splatted
+11.0
+
+julia> p(4) # == 4^2 + √4, one input used twice
+18.0
+
+julia> Parallel(hcat, inv)(1, 2, 4) # one function, three inputs
+1×3 Matrix{Float64}:
+ 1.0  0.5  0.25
+```
+
+With Flux layers:
+
 ```jldoctest
 julia> model = Chain(Dense(3 => 5),
                      Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
@@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
   layers::T
 end
 
+_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}
+
 Parallel(connection, layers...) = Parallel(connection, layers)
 function Parallel(connection; kw...)
   layers = NamedTuple(kw)
   if :layers in keys(layers) || :connection in keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
   end
-  isempty(layers) && return Parallel(connection, ())
   Parallel(connection, layers)
 end
+Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
+  throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))
 
 @layer :expand Parallel
 
-(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
-(m::Parallel)(xs::Tuple) = m(xs...)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument
 
 function _parallel_check(layers, xs)
   nl = length(layers)
-  nx = length(xs)
+  @assert nl > 1 # dispatch handles nl==1 cases
+  nx = length(xs)
   if (nl != nx)
-    throw(ArgumentError(lazy"Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+    throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
   end
 end
 ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
 
-function (m::Parallel)(xs...)
+function (m::Parallel)(x, ys...)
+  xs = (x, ys...)
   _parallel_check(m.layers, xs)
-  m.connection(map(|>, xs, Tuple(m.layers))...)
+  m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
 end
 
+(m::_ParallelONE)(x, ys...) =
+  m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer
+
+(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
+(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity
+
+(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))
+
 Base.getindex(m::Parallel, i) = m.layers[i]
 Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
 Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 95da13f0c9..8e33340611 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -35,6 +35,8 @@ using Flux: activations
     c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu))
     @test c[1] == c[begin]
     @test c[3] == c[end]
+
+    @test Chain(identity)(1,2,3) == (1,2,3) # multiple args become a tuple
   end
 
   @testset "Activations" begin
@@ -228,17 +230,20 @@
   end
 
   @testset "concat size" begin
-    input = randn(10, 2)
+    input = randn32(10, 2)
     @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
     @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
   end
 
   @testset "vararg input" begin
-    inputs = randn(10), randn(5), randn(4)
+    inputs = randn32(10), randn32(5), randn32(4)
     @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
     @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
-    @test Parallel(+, sin, cos)(pi/2) ≈ 1
+    @test Parallel(+, sin, cos)(pi/2) ≈ 1 # one input, several layers
+    @test Parallel(/, abs)(3, -4) ≈ 3/4 # one layer, several inputs
+    @test Parallel(/, abs)((3, -4)) ≈ 3/4
+    @test Parallel(/; f=abs)(3, -4) ≈ 3/4
   end
 
   @testset "named access" begin
@@ -256,9 +261,13 @@
   end
 
   @testset "trivial cases" begin
-    @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
-    @test Parallel(hcat)(1) == hcat()
-    @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
+    # zero inputs, always an error
+    @test_throws ArgumentError Parallel(hcat)()
+    @test_throws ArgumentError Parallel(hcat, inv)()
+    @test_throws ArgumentError Parallel(hcat, inv, sqrt)()
+
+    # zero layers -- not useful... now made an error
+    @test_throws ArgumentError Parallel(hcat)
   end
 
   @testset "connection is called once" begin
@@ -270,6 +279,8 @@
     @test CNT[] == 2
     Parallel(f_cnt, sin)(1)
     @test CNT[] == 3
+    Parallel(f_cnt, sin)(1,2,3)
+    @test CNT[] == 4
   end
 
   # Ref https://github.com/FluxML/Flux.jl/issues/1673
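
As a quick sanity check of the behaviour this diff introduces, here is a minimal usage sketch. It is not part of the patch: it assumes a Flux build with this change applied, and the `two_args`/`one_layer` names, layer sizes, and inputs are purely illustrative.

```julia
using Flux  # assumes Flux with this patch applied

# Chain now accepts several arguments and bundles them into a tuple,
# which a Parallel sub-layer then splats back into separate inputs:
two_args = Chain(Parallel(+, inv, abs2))
two_args(4, 5)           # inv(4) + abs2(5) == 0.25 + 25 == 25.25

# Parallel with a single layer now maps that layer over every input
# before applying the connection (here vcat):
one_layer = Parallel(vcat, Dense(3 => 2))
x1, x2 = rand(Float32, 3), rand(Float32, 3)
size(one_layer(x1, x2))  # Dense applied to x1 and x2, results vcat-ed: (4,)

# Degenerate cases are now explicit errors rather than silent no-ops:
# Parallel(hcat)          throws ArgumentError (no sub-layers)
# Parallel(hcat, inv)()   throws ArgumentError (no inputs)
```

The error cases in the comments mirror the new `@test_throws` assertions added to `test/layers/basic.jl` above.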