Parallel edge-cases #1685

Closed
mcabbott opened this issue Aug 1, 2021 · 6 comments · Fixed by #1862

@mcabbott
Member

mcabbott commented Aug 1, 2021

First, about vararg vs single-argument. I think this is the documented behaviour, using zip, but how surprising is this? Is its early stopping a feature or a footgun?

julia> Parallel(hcat, x->x.+1, x->x.+2, x->x.+3)([0])  # one argument for all mode
1×3 Matrix{Int64}:
 1  2  3

julia> Parallel(hcat, x->x.+1, x->x.+2, x->x.+3)([0], [0])  # vararg => zip mode, zip ignores 3rd layer 
1×2 Matrix{Int64}:
 1  2

julia> Parallel(hcat, x->x.+1, x->x.+2, x->x.+3)([0], [0], [0], [1000])  # zip ignores 4th argument
1×3 Matrix{Int64}:
 1  2  3
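Both truncations come from Base Julia's `zip`, which stops as soon as its shortest argument runs out. The sketch below is plain Julia with no Flux. It assumes, as described above, that the vararg path pairs layers with inputs through `zip`:

```julia
layers = (x -> x .+ 1, x -> x .+ 2, x -> x .+ 3)

# Three layers, two inputs: zip yields only two (layer, input) pairs,
# so the third layer never runs.
outs2 = [f(x) for (f, x) in zip(layers, ([0], [0]))]

# Three layers, four inputs: zip yields three pairs, so [1000] is silently dropped.
outs4 = [f(x) for (f, x) in zip(layers, ([0], [0], [0], [1000]))]

reduce(hcat, outs2)   # 1×2 Matrix: [1 2]
reduce(hcat, outs4)   # 1×3 Matrix: [1 2 3]
```

Comparing the number of layers with the number of inputs before zipping would turn both silent truncations into errors.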

Second, from the docs it's not very clear whether "reducing the output with connection" means pairwise, or vararg. In fact it means pairwise, which is the same (if slightly less efficient) for +, vcat, etc.

julia> myplus(xs...) = begin println("+ ", length(xs)); +(xs...) end;

julia> Parallel(myplus, identity, identity, identity)([1]);
+ 2
+ 2

julia> Parallel(myplus, identity, identity, identity)([1], [1], [1]);
+ 2
+ 2
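Reducing three outputs pairwise means two binary calls to the connection. A vararg design would instead make a single three-argument call. The plain-Julia sketch below records each call's arity instead of printing it; the `arities` log is only for illustration:

```julia
const arities = Int[]   # how many arguments each call to the connection received
myplus(xs...) = (push!(arities, length(xs)); +(xs...))

ys = ([1], [1], [1])    # outputs of three identity branches

# Pairwise: the connection is only ever called with two arguments.
pairwise = reduce(myplus, ys)
pairwise_arities = copy(arities)   # [2, 2]

# Vararg: a single call that sees every branch output at once.
empty!(arities)
vararg = myplus(ys...)
vararg_arities = copy(arities)     # [3]
```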

Third, it currently allows the construction of a layer with no sub-layers, but this cannot be called. Is it desirable to allow this, so that some automatic generation does not produce errors in a trivial case, even if the result cannot be called? Would it be better to error on construction? Or should calling this be given some meaning: if connection is vararg, then possibly Parallel(myplus)([1]) == +([1])?

julia> Parallel(myplus)
Parallel(myplus, )

julia> ans([1])
ERROR: ArgumentError: reducing over an empty collection is not allowed

In fact, even for one sub-layer there are surprises. Probably this should run, but should it call the connection with one argument, or not? (At present, not).

julia> Parallel(myplus, identity)([1]);
ERROR: MethodError: no method matching iterate(::typeof(identity))

julia> Parallel(myplus, (identity,))([1])
1-element Vector{Int64}:
 1

julia> Parallel(myplus, Dense(1,1));  # constructor runs...

julia> Flux.trainable(Parallel(myplus, Dense(1,1)))  # ... but this causes `show` to fail
ERROR: MethodError: no method matching iterate(::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})

julia> Parallel(myplus, Dense(1,1))([1])  # same error on calling
ERROR: MethodError: no method matching iterate(::Dense

(Split off from discussion in #1681, which is really orthogonal.)
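The zero-branch and one-branch behaviour both follow from reducing pairwise. With no outputs, `reduce` has nothing to fold and throws. With one output, it returns that output without ever calling the connection, whereas a vararg design would call the connection once with a single argument. A plain-Julia sketch; `apply_pairwise` and `apply_vararg` are hypothetical helpers, not Flux functions:

```julia
const ncalls = Ref(0)   # counts how often the connection actually runs
myplus(xs...) = (ncalls[] += 1; +(xs...))

# Hypothetical helpers: run every branch on the same input, then combine the
# outputs pairwise (via reduce) or all at once (vararg).
apply_pairwise(conn, layers::Tuple, x) = reduce(conn, map(f -> f(x), layers))
apply_vararg(conn, layers::Tuple, x) = conn(map(f -> f(x), layers)...)

# Zero branches: map gives (), and reducing an empty tuple throws
# (an ArgumentError in the transcript above; the type may vary by Julia version).
empty_err = try
    apply_pairwise(myplus, (), [1])
    nothing
catch e
    e
end

# One branch, pairwise: the single output comes back and myplus never runs.
ncalls[] = 0
one_pairwise = apply_pairwise(myplus, (identity,), [1])
calls_pairwise = ncalls[]

# One branch, vararg: myplus runs once, with a single argument.
ncalls[] = 0
one_vararg = apply_vararg(myplus, (identity,), [1])
calls_vararg = ncalls[]
```

With zero branches, the vararg form would call `myplus()` with no arguments. That only works if the connection defines a zero-argument method.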

@ToucheSir
Member

> First, about vararg vs single-argument. I think this is the documented behaviour, using zip, but how surprising is this? Is its early stopping a feature or a footgun?

The contract for Parallel has 3 supported cases:

  1. N layers, 1 input
  2. N layers, N inputs
  3. 1 layer, N inputs

Anything else should probably be guarded against in the call method. We could use dispatch for this if we require Parallel.layers to be an NTuple, but that would preclude the changes in #1681. That's a tradeoff I'm willing to make, but it's worth discussing first.
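A runtime guard for the three supported cases could look like the plain-Julia sketch below. `run_branches` is a hypothetical function, not Flux's implementation. Any combination outside the three cases throws a descriptive error instead of being truncated by `zip`.

```julia
# Hypothetical forward pass that accepts exactly the three supported cases.
function run_branches(connection, layers::Tuple, xs...)
    nl, nx = length(layers), length(xs)
    outputs = if nx == 1
        map(f -> f(xs[1]), layers)          # N layers, 1 input
    elseif nl == nx
        map((f, x) -> f(x), layers, xs)     # N layers, N inputs
    elseif nl == 1 && nx > 1
        map(layers[1], xs)                  # 1 layer, N inputs
    else
        throw(ArgumentError("got $nl layers and $nx inputs; expected one input, " *
                            "one input per layer, or a single layer"))
    end
    return reduce(connection, outputs)      # pairwise, as in the current design
end

three = (x -> x .+ 1, x -> x .+ 2, x -> x .+ 3)
run_branches(hcat, three, [0])                     # [1 2 3]
run_branches(hcat, (x -> x .+ 1,), [0], [0], [0])  # [1 1 1]
```

Tuple lengths are part of the tuple's type, so the same case split could in principle be written with dispatch instead of `if` branches.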

> Second, from the docs it's not very clear whether "reducing the output with connection" means pairwise, or vararg. In fact it means pairwise, which is the same (if slightly less efficient) for +, vcat, etc.

The pairwise part is inherited from SkipConnection. I'm ambivalent about which is better since common connections like [hv]cat are more optimized with reduce, but it would be good to document.
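For `hcat`, the pairwise route and `reduce` give the same array but do different amounts of work. `foldl(hcat, ys)` builds a growing intermediate matrix at every step. Base's specialized `reduce(hcat, ys)` method takes an `AbstractVector` of arrays and builds the result in one go, so the outputs below are collected into a vector. A plain-Julia sketch:

```julia
ys = [fill(i, 3) for i in 1:4]   # four branch outputs, each a length-3 vector

a = foldl(hcat, ys)    # pairwise: hcat(hcat(hcat(y1, y2), y3), y4)
b = reduce(hcat, ys)   # Base's specialized method for a vector of arrays
c = hcat(ys...)        # vararg: one call with every output

size(b)                # (3, 4)
```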

> Third, it currently allows the construction of a layer with no sub-layers, but this cannot be called. Is it desirable to allow this, so that some automatic generation does not produce errors in a trivial case, even if the result cannot be called? Would it be better to error on construction? Or should calling this be given some meaning: if connection is vararg, then possibly Parallel(myplus)([1]) == +([1])?
> ...
> In fact, even for one sub-layer there are surprises. Probably this should run, but should it call the connection with one argument, or not? (At present, not).

This is just constructor confusion between the default and https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl#L414. There are a number of ways to resolve this. The shortest I found was to (as mentioned above) constrain T <: NTuple.

@mcabbott
Member Author

mcabbott commented Aug 1, 2021

> constructor confusion between the default and

Oh, now I see, thanks. As you say, it seems simplest to always store a tuple, even for one element, perhaps just by moving that line into an inner constructor. But NTuple demands a uniform element type, which would forbid e.g. (identity, Dense(...)).
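The sketch below illustrates the "always store a tuple" idea with two hypothetical structs instead of Flux's own `Parallel`. `Loose` reproduces the current ambiguity: a single non-tuple layer is captured by the default two-argument constructor. `Fixed` adds a `T<:Tuple` bound, which routes that call to the vararg constructor instead. Bounding by `Tuple` rather than `NTuple` still allows branches of different types, and no inner constructor is needed.

```julia
myplus(xs...) = +(xs...)

# Before: no bound on `layers`, plus a vararg outer constructor.
struct Loose{F, T}
    connection::F
    layers::T
end
Loose(connection, layers...) = Loose(connection, layers)

# After: `layers` must be a Tuple (not an NTuple, so element types may differ).
struct Fixed{F, T<:Tuple}
    connection::F
    layers::T
end
Fixed(connection, layers...) = Fixed(connection, layers)

# Same forward pass for both: run each branch on x, then reduce pairwise.
(m::Loose)(x) = reduce(m.connection, map(f -> f(x), m.layers))
(m::Fixed)(x) = reduce(m.connection, map(f -> f(x), m.layers))

loose = Loose(myplus, identity)               # default constructor wins: layers === identity
fixed = Fixed(myplus, identity)               # not a Tuple, so the vararg method wraps it
mixed = Fixed(myplus, identity, x -> 2 .* x)  # heterogeneous branch types are fine
```

Calling `loose([1])` throws a `MethodError`, because `map` cannot iterate over `identity`. `fixed([1])` returns `[1]`.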

> Anything else should probably be guarded against in the call method. We could use dispatch

I like the sound of this restriction. While dispatch could perhaps be made to give a MethodError on others, it also seems fine to check at runtime and throw a descriptive error.

@darsnack
Member

darsnack commented Aug 1, 2021

Just referencing a parallel conversation that has some relevance to this discussion: #1673 (comment)

I do feel that naming branches will be more useful than preventing the footgun of mismatched branches vs. arguments. For example, keeping track of categorical vs. continuous preprocessing branches, or the plethora of branches in an inception block. We could always address the footgun with a runtime check.
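The plain-Julia sketch below shows what named branches could look like, assuming they are stored as a `NamedTuple`. The branch names and functions are made up. `map` over a `NamedTuple` keeps the names, and `reduce` iterates over its values, so pairwise combination works unchanged.

```julia
# Hypothetical named branches; the names and functions are illustrative only.
branches = (categorical = x -> x .* 10, continuous = x -> x .+ 0.5)

x = [1.0, 2.0]
outs = map(f -> f(x), branches)   # a NamedTuple with the same names
combined = reduce(vcat, outs)     # NamedTuples iterate over their values

outs.continuous                   # each branch's output stays addressable by name
```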

@mcabbott
Member Author

mcabbott commented Aug 1, 2021

> 3. 1 layer, N inputs

Note also that this case isn't currently supported. Its zip stops early and ignores the remaining inputs:

julia> Parallel(hcat, (x->x.+1,))([0], [0], [0])
1-element Vector{Int64}:
 1

julia> Parallel(myplus, (identity,))([1], [10], [100])
1-element Vector{Int64}:
 1

@ToucheSir
Member

My main worry with a dynamic runtime check is that one of AD, GPUCompiler or whatever new tracing functionality is coming down the pipe with Symbolics/tracing will really dislike it. Perhaps that's a non-issue though.

As for the not actually present case 3, take it as a feature request ;)

@DhairyaLGandhi
Member

DhairyaLGandhi commented Aug 23, 2021

I think the vararg version can be modified to do the pairwise thing by managing f when needed. It seems to be the expected way for Parallel to work too, since the idea is that you can execute N branches and manipulate the output of these together.
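One possible reading of "managing f" is sketched below in plain Julia, as an assumption rather than the design actually adopted. The layer always calls its connection with every output at once, and a connection that only has a binary method is adapted by a wrapper that folds pairwise. `pairwise` and `combine_vararg` are hypothetical names.

```julia
# Wrap a binary function so it accepts any number of arguments by folding pairwise.
pairwise(f) = (xs...) -> reduce(f, xs)

# A layer designed this way would call its connection once, with every output.
combine_vararg(connection, outputs::Tuple) = connection(outputs...)

binplus(a, b) = a .+ b          # a connection with only a two-argument method
outs = ([1, 2], [3, 4], [5, 6])

combine_vararg(vcat, outs)               # vcat already accepts any number of arguments
combine_vararg(pairwise(binplus), outs)  # binplus on its own would throw a MethodError here
```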

Inner constructors are usually bad for AD, and we definitely don't want to restrict what the types of the branches are. Storing layers as a tuple should suffice here.

Good to avoid runtime checks if we can get most of the way there without them.

5 participants