Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Parallel(+, f)(x, y, z) to work like broadcasting, and enable Chain(identity, Parallel(+, f))(x, y, z) #2393

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```

A chain may be called with multiple arguments, which is equivalent to calling it
with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
to mean the same as several arguments:

```jldoctest
julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple
(1, 2, 3)
nothing

julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2
x = (4, 5)
25.25
```

For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
Expand All @@ -46,9 +60,10 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@layer :expand Chain # the + opts-in to container-style pretty-printing
@layer :expand Chain # the option :expand opts-in to container-style pretty-printing

(c::Chain)(x) = _applychain(c.layers, x)
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))

@generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
symbols = vcat(:x, [gensym() for _ in 1:N])
Expand All @@ -68,6 +83,7 @@ end
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
_show_layers(io, c.layers)
Expand Down Expand Up @@ -475,8 +491,11 @@ end
Create a layer which passes an input array to each path in
`layers`, before reducing the output with `connection`.

Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Obeys the similar rules to broadcasting:
* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
* With multiple inputs and multiple layers, one input is passed to each layer,
thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.

Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
Expand All @@ -486,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`.

# Examples

```jldoctest
julia> p = Parallel(+, abs2, sqrt);

julia> p(3, 4) # == 3^2 + √4, two functions two inputs
11.0

julia> p((3, 4)) # tuple is always splatted
11.0

julia> p(4) # == 4^2 + √4, one input used twice
18.0

julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs
1×3 Matrix{Float64}:
1.0 0.5 0.25
```

With Flux layers:

```jldoctest
julia> model = Chain(Dense(3 => 5),
Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
Expand Down Expand Up @@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
layers::T
end

_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}

Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
end
isempty(layers) && return Parallel(connection, ())
Parallel(connection, layers)
end
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

@layer :expand Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

function _parallel_check(layers, xs)
nl = length(layers)
nx = length(xs)
@assert nl > 1 # dispatch handles nl==1 cases
nx = length(xs)
if (nl != nx)
throw(ArgumentError(lazy"Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

function (m::Parallel)(xs...)
function (m::Parallel)(x, ys...)
xs = (x, ys...)
_parallel_check(m.layers, xs)
m.connection(map(|>, xs, Tuple(m.layers))...)
m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
end

(m::_ParallelONE)(x, ys...) =
m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer

(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity

(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Expand Down
23 changes: 17 additions & 6 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using Flux: activations
c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu))
@test c[1] == c[begin]
@test c[3] == c[end]

@test Chain(identity)(1,2,3) == (1,2,3) # multiple args become a tuple
end

@testset "Activations" begin
Expand Down Expand Up @@ -228,17 +230,20 @@ using Flux: activations
end

@testset "concat size" begin
input = randn(10, 2)
input = randn32(10, 2)
@test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
@test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
end

@testset "vararg input" begin
inputs = randn(10), randn(5), randn(4)
inputs = randn32(10), randn32(5), randn32(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
@test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
@test Parallel(+, sin, cos)(pi/2) ≈ 1
@test Parallel(+, sin, cos)(pi/2) ≈ 1 # one input, several layers
@test Parallel(/, abs)(3, -4) ≈ 3/4 # one layer, several inputs
@test Parallel(/, abs)((3, -4)) ≈ 3/4
@test Parallel(/; f=abs)(3, -4) ≈ 3/4
end

@testset "named access" begin
Expand All @@ -256,9 +261,13 @@ using Flux: activations
end

@testset "trivial cases" begin
@test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
@test Parallel(hcat)(1) == hcat()
@test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
# zero inputs, always an error
@test_throws ArgumentError Parallel(hcat)()
@test_throws ArgumentError Parallel(hcat, inv)()
@test_throws ArgumentError Parallel(hcat, inv, sqrt)()

# zero layers -- not useful... now made an error
@test_throws ArgumentError Parallel(hcat)
end

@testset "connection is called once" begin
Expand All @@ -270,6 +279,8 @@ using Flux: activations
@test CNT[] == 2
Parallel(f_cnt, sin)(1)
@test CNT[] == 3
Parallel(f_cnt, sin)(1,2,3)
@test CNT[] == 4
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
Expand Down
Loading