
Simplify trainable, functor and Parallel #1862

Merged · 7 commits · Feb 5, 2022

Conversation

@mcabbott (Member) commented Feb 5, 2022

This PR does a few things relating to Functors / Optimisers:

  • Changes all uses of trainable to return a NamedTuple, which is what Optimisers now wants (a schematic example follows this list).
  • Removes the use of trainable on Parallel: since all of its fields are trainable, the layers are just one wrapper deeper.
  • Removes the inner constructor from Chain, so that it can simply be @functor Chain. Fixes #1857 (Chain forgets names under fmap).
  • Likewise removes the inner constructor from Maxout.
  • Adjusts the show code not to use trainable. This is most of the reason all these changes are in one PR.
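Schematically, the NamedTuple style for trainable looks like this (the layer and field names here are illustrative, not the exact diff):

# old style: an anonymous Tuple of the trainable fields
Flux.trainable(bn::BatchNorm) = (bn.β, bn.γ)

# new style: a NamedTuple, so that fields can be matched by name downstream
Flux.trainable(bn::BatchNorm) = (β = bn.β, γ = bn.γ)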

The downside of no longer hiding the Tuple inside Chain from Functors.jl is that fcollect, and hence Flux.modules, will now see it. This is why the tests initially failed; they have been updated to allow this. I'm not too sure what Flux.modules is for, and whether this matters.

The PR also changes Parallel to call its connection exactly once, always, and to accept, with N layers, either 1 input or exactly N inputs (a short sketch of the new rule is below). It used to zip inputs with layers, but (IMO) allowing N-1 inputs just seems like a bug magnet. These changes can be separated if anyone feels strongly. Fixes #1685, closes #1698.
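For illustration only (not part of the diff), a minimal REPL sketch of the intended input rule, using two small Dense layers with arbitrary sizes:

julia> m = Parallel(+, Dense(2, 3), Dense(2, 3));       # N = 2 layers

julia> m(rand(Float32, 2)) |> size                       # 1 input: the same input goes to every layer
(3,)

julia> m(rand(Float32, 2), rand(Float32, 2)) |> size     # exactly N inputs: one per layer
(3,)

Any other number of inputs is meant to error rather than being zipped.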

@mcabbott changed the title from "Functor" to "Simplify trainable, functor and Parallel" on Feb 5, 2022
@darsnack (Member) commented Feb 5, 2022

modules provides a flattened iterator over the nodes one level above the leaves. So it gives you the ordered list of nodes seen during the traversal. Its use case is regularizers like

sum(regularize(l) for l in modules(model))

where regularize can dispatch by layer type.
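To make the dispatch idea concrete, a hypothetical definition (regularize is not a Flux function, just a user-defined name used here):

regularize(l::Dense) = sum(abs2, l.weight)   # L2 penalty on Dense weights only
regularize(l) = 0f0                          # every other node contributes nothing

penalty = sum(regularize(l) for l in Flux.modules(model))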

So I think for a breaking change, we can put a note in the docs/NEWS and update the tests, but this will not affect the expected use case.

@mcabbott mcabbott added this to the v0.13 milestone Feb 5, 2022
@mcabbott (Member Author) commented Feb 5, 2022

Ok. Possibly it should filter out all types owned by Base? Before this PR, with Parallel you get a tuple of layers in the output, which you don't get with Chain:

julia> Flux.modules(Chain(Parallel(+, Dense(2,2,sin), Dense(2,2,cos)), Dense(2,2,tan)))
6-element Vector{Any}:
 Chain(Parallel(+, Dense(2, 2, sin), Dense(2, 2, cos)), Dense(2, 2, tan))  # 18 parameters
 Parallel(+, Dense(2, 2, sin), Dense(2, 2, cos))  # 12 parameters
 (Dense(2, 2, sin), Dense(2, 2, cos))
 Dense(2, 2, sin)    # 6 parameters
 Dense(2, 2, cos)    # 6 parameters
 Dense(2, 2, tan)    # 6 parameters

Or perhaps that's just more confusing and staying closer to the literal structure is better.
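For what it's worth, a rough sketch of what such a filter could look like (purely illustrative, not something this PR implements): keep only nodes whose type is defined outside Base and Core.

keep(x) = parentmodule(typeof(x)) ∉ (Base, Core)   # drops bare Tuples and other Base/Core containers
layers = filter(keep, Flux.modules(model))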

@testset "Utils" begin
include("utils.jl")
end
@testset verbose=true "Flux.jl" begin

This PR also adds an overall testset, so that all tests are run even if one fails near the start.

@darsnack (Member) commented Feb 5, 2022

> Or perhaps that's just more confusing and staying closer to the literal structure is better.

I think so. And ultimately, only the last three elements above will matter for the expected use case.

@codecov-commenter commented
Codecov Report

Merging #1862 (9ef0a46) into master (8d3b8d3) will increase coverage by 0.09%.
The diff coverage is 84.84%.


@@            Coverage Diff             @@
##           master    #1862      +/-   ##
==========================================
+ Coverage   73.85%   73.94%   +0.09%     
==========================================
  Files          28       28              
  Lines        1683     1689       +6     
==========================================
+ Hits         1243     1249       +6     
  Misses        440      440              
Impacted Files Coverage Δ
src/deprecations.jl 22.22% <ø> (ø)
src/layers/show.jl 72.36% <60.00%> (-1.25%) ⬇️
src/layers/basic.jl 76.74% <86.36%> (+1.94%) ⬆️
src/layers/normalise.jl 82.38% <100.00%> (ø)
src/layers/recurrent.jl 75.45% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Resolved review threads: src/layers/normalise.jl, src/layers/basic.jl (two threads)
@mcabbott mcabbott merged commit 9b21e2c into FluxML:master Feb 5, 2022
@mcabbott mcabbott deleted the functor branch February 5, 2022 19:30
Development

Successfully merging this pull request may close these issues.

  • Chain forgets names under fmap (#1857)
  • Parallel edge-cases
4 participants