
Support Vararg Chain (Chain of Parallel) #2101

Closed
cstjean wants to merge 2 commits

Conversation

@cstjean commented Nov 7, 2022

Closes #2100

As mentioned in #2100 (comment), this will break any code using `Chain()` as the identity function. I need a decision on whether this is acceptable, or whether I should special-case it.
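
A minimal sketch of the breakage, assuming the new signature threads the full argument tuple through the (empty) chain:

```julia
using Flux

Chain()(3.0)  # today: returns 3.0, i.e. Chain() acts as the identity

# With a Vararg signature like (c::Chain)(xs...), the empty chain would
# presumably return the argument tuple instead, so Chain()(3.0) would
# become (3.0,) rather than 3.0.
```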

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@ToucheSir (Member) commented Nov 7, 2022

This reminded me that we discussed multi-input Chains just over two years ago; see the discussion starting here. One unresolved issue is whether users will expect Chain(twoargfn, twoargfn) to work if Chain(twoargfn) does. It does not feel like we have enough information to know whether it is safe to splat the output of one layer into the next versus passing it on wholesale.
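
A toy example of that ambiguity (twoargfn is hypothetical):

```julia
# Hypothetical two-argument layer:
twoargfn(a, b) = a .+ b

# Chain(twoargfn)(x, y) is unambiguous under the proposal. But in
# Chain(twoargfn, twoargfn), the first call returns a single array while
# the second expects two arguments; neither always-splatting nor
# always-passing-wholesale covers both this case and layers that return
# a tuple meant to be consumed as one value.
```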

@darsnack (Member) commented Nov 7, 2022

This was also discussed in great detail in #1698, where there was a desire to remove (m::Parallel)(xs::Tuple) = m(xs...). There, I suggested that a change like this PR was necessary to make that happen.

A couple of details that are relevant from that discussion:

  • The Julia convention for multiple outputs is a Tuple, with explicit splatting to downstream functions
  • A Splat layer is an option for convenient explicit splatting that avoids needing a (m)(xs::Tuple) = m(xs...) method for every multi-arg layer (see the sketch after this list)
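
For concreteness, a minimal sketch of such a Splat layer (hypothetical here; not an existing Flux layer):

```julia
using Flux

# Hypothetical Splat wrapper: forwards a Tuple as positional arguments.
struct Splat{F}
    f::F
end
(s::Splat)(xs::Tuple) = s.f(xs...)
Flux.@functor Splat  # keep any parameters inside `f` trainable

# Usage sketch: Parallel(tuple, ...) yields a Tuple of branch outputs,
# which Splat(vcat) explicitly splats into vcat:
# Chain(Parallel(tuple, Dense(2 => 3), Dense(4 => 3)), Splat(vcat), Dense(6 => 1))
```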

The issue (#2100) is a little ambiguous. You can insert Parallel at the end or in the middle of a Chain, and it "just" works; only inserting it at the very start breaks the Vararg calling variant. If that is all that's missing, then a simpler change:

(m::Chain)(xs...) = _applychain(m.layers, xs)

would work. This still keeps Chain 1-in-1-out when passing arguments between layers, but optionally auto-tuplifies a Vararg input at the call site.
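
A sketch of the difference this one-liner would make (layer sizes are illustrative):

```julia
using Flux

m = Chain(Parallel(vcat, Dense(2 => 3), Dense(4 => 3)), Dense(6 => 1))
x, y = rand(Float32, 2), rand(Float32, 4)

m((x, y))  # already works: the Tuple reaches Parallel's Tuple dispatch
m(x, y)    # a MethodError today; would work with the proposed method
```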

@cstjean (Author) commented Nov 7, 2022

(m::Parallel)(xs::Tuple) = m(xs...)

Is this documented anywhere? If not, I can make a PR. It's a crucial part of combining embeddings with other inputs.
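
For example (a minimal sketch; the layer sizes and the use of Embedding are illustrative):

```julia
using Flux

# Combine an embedded token id with dense features; the Parallel at the
# head of the Chain receives the (token, features) Tuple and splats it:
model = Chain(
    Parallel(vcat,
        Embedding(1000 => 8),  # token id -> 8-dim embedding
        Dense(5 => 8)),        # 5 raw features -> 8 dims
    Dense(16 => 1))

token, features = 42, rand(Float32, 5)
model((token, features))
```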

In any case, I agree with either the status quo, or the

(m::Chain)(xs...) = _applychain(m.layers, xs)

proposal.

@darsnack (Member) commented Nov 7, 2022

Is this documented anywhere?

Surprisingly, the Tuple <=> Vararg equivalence is not in the docstring for Parallel. It should be.
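
For the docstring, the equivalence could be illustrated with something like this (sizes illustrative):

```julia
using Flux

m = Parallel(vcat, Dense(2 => 3), Dense(4 => 3))
x, y = rand(Float32, 2), rand(Float32, 4)

m(x, y) == m((x, y))  # true: the Tuple form is splatted into the Vararg form
```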

My initial comment was just to surface prior discussions on the topic. Having had the chance to review those discussions in more detail, I think we can condense to two options:

  1. Chain is MIMO (i.e. it accepts Vararg and always splats the intermediate arguments). In this option, multiple inputs and outputs are handled automatically.
  2. Chain is SISO (the status quo). Multiple arguments are explicitly handled via one of two sub-options:
    a. Each multi-argument layer adds its own (m)(xs::Tuple) dispatch (sketched after this list)
    b. We provide a Splat layer, so the splatting is explicit on the user's side rather than the layer author's
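
A sketch of sub-option (2a) for a custom layer (ConcatDense is hypothetical):

```julia
using Flux

# Hypothetical two-input layer that concatenates its inputs:
struct ConcatDense{D}
    dense::D
end
ConcatDense(in1::Int, in2::Int, out::Int) = ConcatDense(Dense(in1 + in2 => out))
(m::ConcatDense)(x, y) = m.dense(vcat(x, y))
(m::ConcatDense)(xs::Tuple) = m(xs...)  # (2a): per-layer Tuple dispatch
Flux.@functor ConcatDense
```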

In the case of (2), the (m::Chain)(xs...) = _applychain(m.layers, xs) method can be added for convenience.

(2a) is what we currently have, and I am okay with all options. Maybe the other maintainers can weigh in.

@cstjean closed this Nov 7, 2022
@ToucheSir (Member) commented Nov 7, 2022
Ideally MIMO would just work, but unfortunately I think splatting intermediates would break models like Chain(mimo_fn, identity, ...). I have no objections to the convenience method, however.
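
For instance (mimo_fn is hypothetical):

```julia
# Hypothetical layer returning multiple outputs as a tuple:
mimo_fn(x, y) = (x .+ y, x .- y)

# If Chain always splatted intermediates, Chain(mimo_fn, identity)(x, y)
# would effectively call identity(x .+ y, x .- y), a MethodError since
# identity takes exactly one argument. Passed wholesale,
# identity((x .+ y, x .- y)) works as intended.
```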

@darsnack (Member) commented Nov 7, 2022

@cstjean were you planning on opening another PR with the convenience method discussed above?

@cstjean (Author) commented Nov 8, 2022

Sure, I can do that.
