
Replace unrolled foldl used to evaluate Chain with a better one #1809

Merged (8 commits), Feb 5, 2022

Conversation

mcabbott (Member) commented Dec 13, 2021

This uses Base's foldl to feed data through a Chain, instead of a hand-written version. The reason to do so is that this greatly reduced time-to-first-gradient in FluxML/Zygote.jl#1126. It's possible that Flux ought to add a compat bound on ChainRules for this, or at least I should check whether Zygote's bound is high enough to guarantee it will work. But first, see if CI likes this.

The latest version does not call either foldl or Base.afoldl, but instead replaces one hand-written version with another, as this seems to perform better.
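
For context, a rough sketch of the two evaluation strategies being compared (the recursive version approximates what Flux used before this PR; details may differ from the actual source):

# Hand-written recursive unrolling, roughly Flux v0.12's applychain:
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(Base.tail(fs), first(fs)(x))
(c::Chain)(x) = applychain(c.layers, x)

# foldl-based evaluation, as in the first version of this PR:
(c::Chain)(x) = foldl((y, f) -> f(y), (x, c.layers...))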

ToucheSir (Member) commented Dec 14, 2021

Would there be any value in defining a foldl wrapper or foldl-like function that has init as a positional arg? That would allow for creating an equivalent (possibly simplified) version of https://github.com/JuliaDiff/ChainRules.jl/blob/a75193768775975fac5578c89d1e5f50d7f358c2/src/rulesets/Base/mapreduce.jl#L342-L377.

(the above need not hold up this PR, just thinking about next steps and how to bring something like VChain into Flux)
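
For illustration, a minimal sketch of such a wrapper (the helper name is hypothetical and not part of this PR):

# A foldl-like helper that takes `init` as a positional argument, so an rrule
# could dispatch on it directly (cf. ChainRules' existing foldl rule linked above):
foldl_init(op, init, itr) = foldl(op, itr; init = init)

# A Chain could then be evaluated as:
# (c::Chain)(x) = foldl_init(|>, x, c.layers)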

mcabbott (Member Author)

The obvious candidate there is Base.mapfoldl_impl(f, op, init, itr), which is 2 hops down from foldl. I've thought about it but haven't quite got there yet.

Agree the unstable path would be nice to have, too. This is just the minimum-effort PR to get started.

CarloLucibello (Member)

Should we add a second-derivative test, since you expressed concerns about that in FluxML/Zygote.jl#1126?

@CarloLucibello CarloLucibello added this to the v0.13 milestone Dec 14, 2021
ToucheSir (Member) commented Dec 14, 2021 via email

@@ -43,11 +43,18 @@ end

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

function (c::Chain)(x)
  if order() < 2
    foldl((y,f) -> f(y), (x, c.layers...))
Member:

Would it be too cheeky to do

Suggested change:
- foldl((y,f) -> f(y), (x, c.layers...))
+ Base.afoldl(|>, x, c.layers...)

ref. https://github.com/JuliaLang/julia/blob/dba8f0344ab6f1a63ea4f7928610feb35f014d4f/base/operators.jl#L592.

mcabbott (Member Author):

This was slow when I tried it :( although I did not think of |>

mcabbott (Member Author):

Tidied up, here are times of some options on some very small networks:

julia> using Flux

julia> m1 = Chain(x->vcat(x,x), Dense(2,2,relu));

julia> m2 = Chain(Dense(20,30,relu), Dense(30,10), softmax);

# Tagged version

julia> @btime gradient(sum∘$m1, 3.14)
  16.584 μs (253 allocations: 9.52 KiB)
(0.33772213757038116,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  42.792 μs (353 allocations: 68.00 KiB)

# foldl

julia> @eval Flux (c::Chain)(x) = foldl(|>, (x, c.layers...));

julia> Flux.Zygote.refresh()

julia> @btime gradient(sum∘$m1, 3.14)
  18.625 μs (277 allocations: 11.62 KiB)
(0.33772213757038116,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  45.500 μs (385 allocations: 72.03 KiB)

# afoldl

julia> @eval Flux (c::Chain)(x) = Base.afoldl(|>, x, c.layers...);

julia> Flux.Zygote.refresh()

julia> @btime gradient(sum∘$m1, 3.14)
  84.167 μs (496 allocations: 19.03 KiB)
(0.33772213757038116,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  117.917 μs (614 allocations: 77.62 KiB)

# Unstable for loop

julia> @eval Flux begin
       function (c::Chain)(x)
         y = identity(x)
         for f in c.layers
           y = f(y)
         end
         y
       end
       end

julia> Flux.Zygote.refresh()

julia> @btime gradient(sum∘$m1, 3.14)
  25.250 μs (357 allocations: 13.33 KiB)
(0.33772213757038116,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  58.792 μs (512 allocations: 72.89 KiB)

Member:

Thanks. Looks like afoldl would need an rrule like foldl to be remotely competitive.

mcabbott (Member Author) commented Jan 4, 2022:

I guess it's a sad reflection on Zygote that something so simple should benefit from a complicated hand-written rule. Diffractor:

julia> fs = (sin, cos, x -> vcat(x,x), sum, exp);

julia> @btime Zygote.gradient(x -> foldl(|>, (x, $fs...)), 3)
  6.900 μs (130 allocations: 3.88 KiB)
(2.017262629514019,)

julia> @btime Zygote.gradient(x -> Base.afoldl(|>, x, $fs...), 3)
  72.083 μs (306 allocations: 11.81 KiB)
(2.017262629514019,)

julia> @btime Diffractor.gradient(x -> foldl(|>, (x, $fs...)), 3)
  9.375 μs (205 allocations: 6.52 KiB)
(2.017262629514019,)

julia> @btime Diffractor.gradient(x -> Base.afoldl(|>, x, $fs...), 3)
  3.812 μs (90 allocations: 5.56 KiB)
(2.017262629514019,)

And with ForwardDiff, Tracker, and ReverseDiff too:

julia> @btime ForwardDiff.derivative(x -> foldl(|>, (x, $fs...)), 3)
  689.597 ns (11 allocations: 416 bytes)
2.0172626295140184

julia> @btime ForwardDiff.derivative(x -> Base.afoldl(|>, x, $fs...), 3)
  181.812 ns (6 allocations: 256 bytes)
2.0172626295140184

julia> fs2 = (sin, cos, x -> [x,x], sum, exp);

julia> @btime Tracker.gradient(x -> foldl(|>, (x, $fs2...)), 3)
  850.134 ns (22 allocations: 768 bytes)
(2.017262629514019 (tracked),)

julia> @btime Tracker.gradient(x -> Base.afoldl(|>, x, $fs2...), 3)
  294.457 ns (22 allocations: 768 bytes)
(2.017262629514019 (tracked),)

julia> @btime ReverseDiff.gradient(x -> foldl(|>, (x[1], $fs2...)), [3.0])
  700.290 ns (24 allocations: 1.02 KiB)
1-element Vector{Float64}:
 2.017262629514019

julia> @btime ReverseDiff.gradient(x -> Base.afoldl(|>, x[1], $fs2...), [3.0])
  448.652 ns (24 allocations: 1.02 KiB)
1-element Vector{Float64}:
 2.017262629514019

mcabbott (Member Author) commented Jan 7, 2022:

Oh right, indeed. Yota does understand structs, but it fails on my foldl tests above; not sure why:

julia> Yota.grad(x -> sum(x.a), (a=[1,2], b=[3,4]))
(3, (ChainRulesCore.ZeroTangent(), Tangent{NamedTuple{(:a, :b), Tuple{Vector{Int64}, Vector{Int64}}}}(a = [1, 1],)))

julia> Yota.grad(x -> foldl(|>, (x, fs2...)), 3.0)
ERROR: No deriative rule found for op %7 = foldl(|>, %5)::Float64, try defining it using 
	ChainRulesCore.rrule(::typeof(foldl), ::typeof(|>), ::Tuple{Float64, typeof(sin), typeof(cos), var"#13#14", typeof(sum), typeof(exp)}) = ...

julia> Yota.grad(x -> Base.afoldl(|>, x, fs2...), 3.0)
ERROR: MethodError: no method matching length(::typeof(sin))

julia> m2 = Chain(Dense(2,3,sigmoid));

julia> Yota.grad(m -> sum(m([0.1, 0.2])), m2)
ERROR: No deriative rule found for op %15 = broadcasted(%12, %14)::Broadcasted{}, try defining it using ...

Cc @dfdx in case he'd like to join this hallway discussion.

Member:

The performance gap is really only due to Zygote unrolling chains, which isn't per se necessary. It also seems foldl wouldn't work generically, especially when composing layers, chains of chains, etc.

dfdx commented Jan 8, 2022:

@mcabbott The rule for foldl exists in ChainRules, so it must be a bug. I'm looking into it. Thanks for pinging me!

mcabbott (Member Author):

The latest version of this PR replaces foldl with an @generated applychain. Applied to the tiny examples above (#1809 (comment)), this is faster, and it has shorter compile times (shown elsewhere).

julia> using Flux

julia> m1 = Chain(x->vcat(x,x), Dense(2,2,relu));

julia> m2 = Chain(Dense(20,30,relu), Dense(30,10), softmax);

julia> @btime gradient(sum∘$m1, 3.14)
  min 14.917 μs, mean 16.347 μs (246 allocations, 9.42 KiB)
(0.310660183429718,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  min 41.417 μs, mean 50.995 μs (353 allocations, 68.27 KiB)

julia> m1 = Chain([x->vcat(x,x), Dense(2,2,relu)]);

julia> m2 = Chain([Dense(20,30,relu), Dense(30,10), softmax]);

julia> @btime gradient(sum∘$m1, 3.14)
  min 34.333 μs, mean 35.980 μs (459 allocations, 15.06 KiB)
(0.0,)

julia> @btime gradient(sum∘$m2, rand(20,20));
  min 73.500 μs, mean 85.917 μs (681 allocations, 76.66 KiB)

The last times here show a chain with layers stored as a vector. This is not type-stable, but helps with compile times. On very small networks like this, it is significantly slower, so it probably cannot be the default.
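
For reference, a rough sketch of the @generated unrolling described above (a sketch only; the merged code may differ in detail, and the empty-chain case is not handled here):

# Unroll Chain evaluation at compile time: for a Tuple of N layers, generate
#   y1 = layers[1](x); y2 = layers[2](y1); ...
# and return the last value.
@generated function applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  symbols = vcat(:x, [gensym() for _ in 1:N])
  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
  Expr(:block, calls...)
end

(c::Chain)(x) = applychain(c.layers, x)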

dfdx:

FYI: latest master of Yota now correctly handles foldl. I haven't tested it with the latest version of this PR though, since I'm currently busy making the underlying tracer faster, so I'm going to wait until the PR is done and catch up later.

@mcabbott mcabbott marked this pull request as draft January 4, 2022 00:51
DhairyaLGandhi (Member) left a comment:

We should simplify this, because there are now different rules for how composed chains work.


mcabbott (Member Author)

The latest version of this replaces foldl with an @generated function which explicitly unrolls. This is not only faster to compile, it is also faster to run. With the example from FluxML/Zygote.jl#1126 (comment), the final gradient is 30% faster, and the first one is 2× faster:

julia> @time loss(lr_images);  # first Flux.Chain call, cold start, no Zygote
  4.490344 seconds (30.96 M allocations: 1.634 GiB, 7.28% gc time, 99.81% compilation time)  # Flux v0.12.8
  3.608894 seconds (34.26 M allocations: 1.801 GiB, 8.63% gc time, 99.84% compilation time)  # @generated
  3.699612 seconds (34.25 M allocations: 1.800 GiB, 8.80% gc time, 99.82% compilation time)

julia> @time loss(lr_images);
  0.005024 seconds (2.46 k allocations: 242.812 KiB)

julia> @btime loss($lr_images);
  min 7.202 ms, mean 7.648 ms (2443 allocations, 256.98 KiB)  # Flux v0.12.8
  min 5.533 ms, mean 5.696 ms (2461 allocations, 242.78 KiB)  # @generated
  min 5.639 ms, mean 5.881 ms (2368 allocations, 233.55 KiB)  # Vector chain

julia> loss_grad(x, ps) = gradient(() -> loss(x), ps);

julia> ps = Flux.params(model);

julia> @time loss_grad(lr_images, ps);  # first Zygote.gradient, cold start
 62.895168 seconds (67.77 M allocations: 3.523 GiB, 0.99% gc time, 99.92% compilation time)  # Flux v0.12.8
 29.646482 seconds (64.41 M allocations: 3.373 GiB, 2.02% gc time, 99.88% compilation time)  # @generated
 13.677340 seconds (63.15 M allocations: 3.308 GiB, 3.99% gc time, 99.74% compilation time)  # Vector chain

julia> @time loss_grad(lr_images, ps);  # second
  0.028015 seconds (12.95 k allocations: 2.275 MiB)

julia> @btime $loss_grad($lr_images, $ps);
  min 37.933 ms, mean 38.574 ms (13052 allocations, 2.93 MiB)  # Flux v0.12.8
  min 29.283 ms, mean 29.730 ms (12941 allocations, 2.28 MiB)  # @generated
  min 30.516 ms, mean 30.957 ms (13436 allocations, 1.62 MiB)  # Vector chain

julia> length(model.layers[1].layers.layers)  # count of resblock(channels) layers
15

Also included in this comparison is a Chain backed by a Vector. This is type-unstable, and not as fast for very small networks. But for larger ones, it is no slower, and improves compile time by another factor of 2. It's included in this PR for now, but obviously needs tests if it is to stay.
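
A minimal sketch of the Vector-backed path (assuming the layers are stored in an AbstractVector; the merged code may differ):

# Type-unstable fallback: when layers live in a Vector rather than a Tuple,
# just loop. This avoids specialising (and recompiling) on the full tuple type.
function applychain(layers::AbstractVector, x)
  for f in layers
    x = f(x)
  end
  return x
end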

@mcabbott mcabbott changed the title Use foldl to evaluate Chain Replace unrolled foldl used to evaluate Chain with a better one Jan 11, 2022
Comment on lines 81 to 82
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
mcabbott (Member Author):

BTW, in addition to a hand-written foldl, the function Flux.activations is just accumulate(|>, m1.layers; init=x1). Since we don't support Julia < 1.5, we could just replace it.

Member:

I thought accumulate would face the same issue as foldl, namely that the rrule doesn't consider init? This PR need not be concerned with activations either way; we can kick that can down the road until the rrules are tweaked or someone complains about performance.

mcabbott (Member Author):

Indeed, you could likewise do accumulate(|>, (x, m.layers...)). But happy to leave it alone for now.
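
For illustration, the earlier claim that Flux.activations is just accumulate with init (a hypothetical session with arbitrary toy layers):

julia> using Flux

julia> m = Chain(x -> x .+ 1, x -> 2 .* x);

julia> Flux.activations(m, [1.0, 2.0])
([2.0, 3.0], [4.0, 6.0])

julia> accumulate(|>, m.layers; init = [1.0, 2.0])
([2.0, 3.0], [4.0, 6.0])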

DhairyaLGandhi (Member) commented Jan 11, 2022

For SciML, we might see regressions in higher-order chains or in smaller chains.

ToucheSir (Member)

@ChrisRackauckas would you be able to offer a thumbs up/down on whether this is something SciML could tolerate? Also @pxl-th I don't remember if you've tested this, but if not it may be interesting for your benchmark in FluxML/Zygote.jl#1119.

ChrisRackauckas (Member)

This is fine. We use FastChain for small chains anyway, and resort to Flux when things get big, so we should be fine. The small-end gap will grow after SciML/DiffEqFlux.jl#671, and I'm not sure what to do after that divergence, but it'll be an interesting discussion to have.

@mcabbott mcabbott marked this pull request as ready for review February 5, 2022 02:40
codecov-commenter commented Feb 5, 2022

Codecov Report

Merging #1809 (f60da1a) into master (9b21e2c) will increase coverage by 0.07%.
The diff coverage is 83.33%.


@@            Coverage Diff             @@
##           master    #1809      +/-   ##
==========================================
+ Coverage   74.50%   74.57%   +0.07%     
==========================================
  Files          28       28              
  Lines        1706     1715       +9     
==========================================
+ Hits         1271     1279       +8     
- Misses        435      436       +1     
Impacted Files Coverage Δ
src/layers/show.jl 72.72% <71.42%> (+0.35%) ⬆️
src/layers/basic.jl 77.53% <90.90%> (+0.61%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

src/layers/show.jl: outdated review comment (resolved)
mcabbott (Member Author) commented Feb 5, 2022

Good to go?


9 participants