Skip to content

Commit

Permalink
Merge #1674
Browse files Browse the repository at this point in the history
1674: Generalise Parallel forwards pass r=DhairyaLGandhi a=DhairyaLGandhi

Fix #1673

We should also do a sweep of other layers that may be more restrictive due to excessive typing. By default there should be no typing on the input, or even necessarily on the parameters.

The `(::Parallel)(::Tuple)` case is interesting since the current version assumes that all the elements need to spread out, but  layers may output multiple things as part of their forward pass, and ingest multiple arguments too. This is also the case in the Weave model for example. So we may want to treat tuples as invariant and treat them like a single entity. Currently, the API to pass tuples could be improved since passing tuples to layers would need an extra layer of nesting for example: `p(((x,),))`


Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
bors[bot] and DhairyaLGandhi authored Aug 15, 2021
2 parents 29a96b9 + 97bc446 commit 6168692
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,8 @@ end

@functor Parallel

(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
(m::Parallel)(xs::Tuple) = m(xs...)

Base.getindex(m::Parallel, i) = m.layers[i]
Expand Down
31 changes: 31 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,37 @@ import Flux: activations
@test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names
@test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
@testset "Input domain" begin
struct Input
x
end

struct L1
x
end
(l::L1)(x) = l.x * x
Flux.@functor L1
Base.:*(a::AbstractArray, b::Input) = a * b.x

par = Parallel(+, L1(rand(Float32, 3,3)), L1(rand(Float32, 3,3)))
ip = Input(rand(Float32, 3,3))
ip2 = Input(rand(Float32, 3,3))

@test par(ip) par.layers[1](ip.x) + par.layers[2](ip.x)
@test par(ip, ip2) par.layers[1](ip.x) + par.layers[2](ip2.x)
gs = gradient((par, x...) -> sum(par(x...)), par, ip, ip2)
gs_reg = gradient(par, ip, ip2) do par, x, y
sum(par.layers[1](x.x) + par.layers[2](y.x))
end

for (a,b) in zip(gs[1].layers, gs_reg[1].layers)
@test a.x b.x
end
@test gs[2].x gs_reg[2].x
@test gs[3].x gs_reg[3].x
end
end

@testset "Embedding" begin
Expand Down

0 comments on commit 6168692

Please sign in to comment.