-
-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Parallel Vararg #1698
Make Parallel Vararg #1698
Conversation
This seems out of sync with master? We already have https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl#L447-L449. |
Yeah, there was a small conflict, but I fixed that. |
master: julia> Parallel(+, Dense(3,3))
Error showing value of type Parallel{typeof(+), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}:
ERROR: MethodError: no method matching iterate(::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
...
Stacktrace:
[1] trainable(m::Parallel{typeof(+), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}})
@ Flux ~/Downloads/new_clones/Flux.jl/src/layers/basic.jl:456 |
This will be breaking since the previous semantics only required that We should still make this change though. |
Of course, the N>3 case would need all the outputs. Maybe vararg is the wrong thing to call this. |
Not quite, if there are 5 branches, then |
Agree we should do this, and that it's breaking. While breaking things, it should probably be adjusted to handle the 3rd case of #1685
Needs tests & doc updates, obviously. |
I think the last case is something better handled on the user side. If someone has multiple inputs, it's easier to have a method that accepts a tuple of arguments and forward that as necessary than it is for us to guess where these arguments go. (m::Layer)(x::Tuple) = m(x[1], x[2], ...) For the struct MultiInput{T}
W::T
end
(m::MultiInput)(x, y) = m.W * x, m.W * y, x * y
x = (rand(3,3), rand(3,3))
l = Parallel((x...) -> identity.(x), # This will be annoying to deal with
MultiInput(rand(3,3)),
MultiInput(rand(3,3))) This would need two things: One is that we remove Another case that this handles better is N inputs M outputs. All these cases are subsets of treating inputs and outputs as something generic that is better left to the user. |
As discussed here, the tuple method exists to make I am not sure what 1 Layer, N Inputs is even supposed to do? Replicate the layer onto each input? That seems out of the contract for |
Right and the distinction to make there is considering tuple as a single input always. If users want the elements of the tuple to be inputs, they can splat, else pass the tuple along. That way we don't have to wrap tuples multiple times so that the automatic splat produces the correct tuple inputs to layers that expect multiple inputs. One other case would be how |
Wrapping the output like
Users don't control how |
A way to handle the 1 Layer, N Inputs case is to do |
If by replicate you mean "apply" and not "make a copy of, and then apply", then yes. What the layer that does this should be called is up for debate, but our discussion over RE splatting/multiarg over tuples, I feel it is such a deep rabbit hole that we ought to avoid it as much as possible. Trying to guess user intent is nothing if not fraught, and it's better to be strict and consistent rather than inconsistently lenient. |
Yeah, that's what I meant, and the RNN discussion is what was in the back of my head.
Just to clarify, what's consistent here? As I see it, Julia itself, |
|
Currently, |
I don't think anyone disagrees with this. But regardless of how we pass multiple arguments, the expected (and requested) behavior of If we want to be really strict about "multiple inputs/outputs are always tuples" then we should eliminate the |
Correct. Passing along the composite type is what we are doing here. par = Par(f, l1, l2) Said another way, the tuple method doesn't override the case that a function expects multiple arguments, |
No it isn't. Under this PR, the only way for a You're thinking about MIMO in the context of branches, but missing MIMO in the context of the complete |
See the edit, i was replying to Brian, my browser hadn't updated the comments yet. |
Okay so if I have |
It is likely that a model (produces multiple outputs/ receives multiple inputs) in the
I understand what you mean by this. The question is: how to distinguish a tuple expected to be sent to the layers as is from a tuple that needs to be mapped. The current design will always flatten a tuple input, breaking the contract of N inputs. For the MIMO case, master would mean |
I wonder if, instead of having each individual layer handle this, we define a common wrapper layer which solely handles splatting tuples into varargs: struct Splat{T} # arbitrarily chosen name
inner::T
end
@functor Splat
(s::Splat)(x) = s.inner(x...) # rough version, probably needs more error checking Then instead of having The main concern I have with this is compilation overhead. AIUI splatting large tuples is quite slow, and having a bunch of differing input lengths would also trigger recompilation. If those turn out to be a non-issue, however, then I would advocate for a separate layer. |
My main concern is this makes simple operators like splats closer to DSLs, so I would avoid such an approach. |
It's a fine line, isn't it? We had a similar discussion in #1289, and though I was very much against One thing to note is that this is one area other frameworks do very poorly in that we could do well. Maybe a dedicated splat layer (call it |
Well, we'll probably need a better reason or a motivating example. |
Is being stuck in PR review limbo a good one? Because I have a feeling the discussion above will be recapitulated every time we talk about adding new container layers... Edit: removed useless strained analogy, see my next comment about bringing this into a synchronous design discussion. |
Which contract? IMO master doesn't break any contracts, but the current version of this PR breaks the contract for
The solution here is to be explicit. I do see what you mean though — wrapping multiple outputs (which is a tuple) in another tuple will not help when there are multiple MIMO branches.
Instead of guessing how the user wants the inputs distributed, we stay strict to the (m::Parallel)(xs::Tuple) = m.connection(map((f, x) -> f(x), m.layers, xs)...)
(m::Parallel)(x) = m((x,))
(m::Parallel)(xs...) = m(xs) This makes it clear that the Let's use concrete examples to avoid confusion. Here are some cases: # SingleOutput produces something other than a tuple which goes to each Dense
Chain(SingleOutput, Parallel(+, Dense, Dense, Dense)
# MultiOutput produces multiple outputs as a tuple, each of which are passed to each Dense
Chain(MultiOutput, Parallel(+, Dense, Dense, Dense)
# MultiOutput produces multiple outputs
# We explicitly state that this should be kept as one unit to each MultiInput
Chain(MultiOutput, (x...) -> (x, x, x), Parallel(+, MultiInput, MultiInput, MultiInput)
# Each branch takes in differing number of inputs
Chain(ThreeOutput, (x, y, z) -> (x, (y, z)), Parallel(+, Dense, TwoInput))
# Multiple outputs from Parallel is easily handled by the combinator
Chain(Dense, Parallel(MIMO, Dense, Dense, Dense)) Note that stuff like Now, I agree that some of these, like What I am trying to point out is that it is not possible to go in the other direction where we take a tuple and break it into multiple outputs to hit the vararg case. Unless we use something like @ToucheSir's |
Not sure how to interpret Brian's comment, but to Kyle's point about the dense layers, I had a test that showed using multiple layers each receiving multiple inputs. I would prefer to get the invariant tuples done too, I think it's doable. Either way this PR is an improvement. Master will always flatten a tuple input, breaking the contract of N inputs, which I was trying to iterate on. |
I think your browser may have missed another comment update. My point was that we've had a number of PRs that are stalled on design disagreements, and even as somebody who didn't author any of them I feel some frustration about not being able to come to some decision. Perhaps we should arrange some time for every ML/AD call to walk through design discussions? Synchronous communication should be far more efficient, and it'd also address weeks when there aren't enough ecosystem updates to fill the alloted time. |
Funnily enough, if JuliaLang/julia#42717 lands then Base may have resolved this debate for us :) |
This PR makes it so vararg inputs and layers are treated as
zip(layers, inputs)
which are then splat into theconnection
.