Replace unrolled foldl used to evaluate Chain with a better one #1809
Conversation
Would there be any value in defining a […]? (The above need not hold up this PR, just thinking about next steps and how to bring something like […] in.)
The obvious candidate there is […]. Agree the unstable path would be nice to have, too. This is just the minimum-effort PR to get started.
Should we have some second derivative test, since you expressed concerns about it in FluxML/Zygote.jl#1126?
That would be great, thanks for bringing it up.
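Something like this, perhaps (a hedged sketch of the kind of test meant; the layers, sizes, and tolerance here are just placeholders, not the test actually added):

using Flux, Zygote, Test

m = Chain(Dense(3, 3, sigmoid), Dense(3, 1))
x = rand(Float32, 3)

# Forward-over-reverse second derivative through the Chain's evaluation path.
H = Zygote.hessian(x -> sum(m(x)), x)
@test size(H) == (3, 3)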
src/layers/basic.jl (Outdated)
@@ -43,11 +43,18 @@ end

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

function (c::Chain)(x)
  if order() < 2
    foldl((y,f) -> f(y), (x, c.layers...))
Would it be too cheeky to do

  foldl((y,f) -> f(y), (x, c.layers...))   # current
  Base.afoldl(|>, x, c.layers...)          # suggested
This was slow when I tried it :( although I did not think of |>
Tidied up, here are times of some options on some very small networks:
julia> using Flux
julia> m1 = Chain(x->vcat(x,x), Dense(2,2,relu));
julia> m2 = Chain(Dense(20,30,relu), Dense(30,10), softmax);
# Tagged version
julia> @btime gradient(sum∘$m1, 3.14)
16.584 μs (253 allocations: 9.52 KiB)
(0.33772213757038116,)
julia> @btime gradient(sum∘$m2, rand(20,20));
42.792 μs (353 allocations: 68.00 KiB)
# foldl
julia> @eval Flux (c::Chain)(x) = foldl(|>, (x, c.layers...));
julia> Flux.Zygote.refresh()
julia> @btime gradient(sum∘$m1, 3.14)
18.625 μs (277 allocations: 11.62 KiB)
(0.33772213757038116,)
julia> @btime gradient(sum∘$m2, rand(20,20));
45.500 μs (385 allocations: 72.03 KiB)
# afoldl
julia> @eval Flux (c::Chain)(x) = Base.afoldl(|>, x, c.layers...);
julia> Flux.Zygote.refresh()
julia> @btime gradient(sum∘$m1, 3.14)
84.167 μs (496 allocations: 19.03 KiB)
(0.33772213757038116,)
julia> @btime gradient(sum∘$m2, rand(20,20));
117.917 μs (614 allocations: 77.62 KiB)
# Unstable for loop
julia> @eval Flux begin
function (c::Chain)(x)
y = identity(x)
for f in c.layers
y = f(y)
end
y
end
end
julia> Flux.Zygote.refresh()
julia> @btime gradient(sum∘$m1, 3.14)
25.250 μs (357 allocations: 13.33 KiB)
(0.33772213757038116,)
julia> @btime gradient(sum∘$m2, rand(20,20));
58.792 μs (512 allocations: 72.89 KiB)
Thanks. Looks like afoldl would need an rrule like foldl to be remotely competitive.
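For reference, a rough sketch of the general shape such a hand-written rule could take, using a hypothetical apply_pipeline helper (this is not ChainRules' actual foldl rule, and not code from this PR):

using ChainRulesCore

# Hypothetical helper, not part of Flux or ChainRules: pipe x through fs left to right.
apply_pipeline(x, fs...) = Base.afoldl(|>, x, fs...)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(apply_pipeline), x, fs...)
    pullbacks = Any[]
    y = x
    for f in fs
        # Differentiate each layer once on the forward pass, saving its pullback.
        y, back = rrule_via_ad(cfg, f, y)
        push!(pullbacks, back)
    end
    function apply_pipeline_pullback(dy)
        dfs = Any[]
        for back in reverse(pullbacks)
            # Walk the layers backwards, splitting the cotangent into layer and input parts.
            df, dy = back(dy)
            pushfirst!(dfs, df)
        end
        return (NoTangent(), dy, dfs...)
    end
    return y, apply_pipeline_pullback
end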
I guess it's a sad reflection on Zygote that something so simple should benefit from a complicated hand-written rule. Diffractor:
julia> fs = (sin, cos, x -> vcat(x,x), sum, exp);
julia> @btime Zygote.gradient(x -> foldl(|>, (x, $fs...)), 3)
6.900 μs (130 allocations: 3.88 KiB)
(2.017262629514019,)
julia> @btime Zygote.gradient(x -> Base.afoldl(|>, x, $fs...), 3)
72.083 μs (306 allocations: 11.81 KiB)
(2.017262629514019,)
julia> @btime Diffractor.gradient(x -> foldl(|>, (x, $fs...)), 3)
9.375 μs (205 allocations: 6.52 KiB)
(2.017262629514019,)
julia> @btime Diffractor.gradient(x -> Base.afoldl(|>, x, $fs...), 3)
3.812 μs (90 allocations: 5.56 KiB)
(2.017262629514019,)
And now with ForwardDiff, Tracker, and ReverseDiff too:
julia> @btime ForwardDiff.derivative(x -> foldl(|>, (x, $fs...)), 3)
689.597 ns (11 allocations: 416 bytes)
2.0172626295140184
julia> @btime ForwardDiff.derivative(x -> Base.afoldl(|>, x, $fs...), 3)
181.812 ns (6 allocations: 256 bytes)
2.0172626295140184
julia> fs2 = (sin, cos, x -> [x,x], sum, exp);
julia> @btime Tracker.gradient(x -> foldl(|>, (x, $fs2...)), 3)
850.134 ns (22 allocations: 768 bytes)
(2.017262629514019 (tracked),)
julia> @btime Tracker.gradient(x -> Base.afoldl(|>, x, $fs2...), 3)
294.457 ns (22 allocations: 768 bytes)
(2.017262629514019 (tracked),)
julia> @btime ReverseDiff.gradient(x -> foldl(|>, (x[1], $fs2...)), [3.0])
700.290 ns (24 allocations: 1.02 KiB)
1-element Vector{Float64}:
2.017262629514019
julia> @btime ReverseDiff.gradient(x -> Base.afoldl(|>, x[1], $fs2...), [3.0])
448.652 ns (24 allocations: 1.02 KiB)
1-element Vector{Float64}:
2.017262629514019
Oh right, indeed. This does understand structs, but fails on my foldl tests above, not sure why:
julia> Yota.grad(x -> sum(x.a), (a=[1,2], b=[3,4]))
(3, (ChainRulesCore.ZeroTangent(), Tangent{NamedTuple{(:a, :b), Tuple{Vector{Int64}, Vector{Int64}}}}(a = [1, 1],)))
julia> Yota.grad(x -> foldl(|>, (x, fs2...)), 3.0)
ERROR: No deriative rule found for op %7 = foldl(|>, %5)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(foldl), ::typeof(|>), ::Tuple{Float64, typeof(sin), typeof(cos), var"#13#14", typeof(sum), typeof(exp)}) = ...
julia> Yota.grad(x -> Base.afoldl(|>, x, fs2...), 3.0)
ERROR: MethodError: no method matching length(::typeof(sin))
julia> m2 = Chain(Dense(2,3,sigmoid));
julia> Yota.grad(m -> sum(m([0.1, 0.2])), m2)
ERROR: No deriative rule found for op %15 = broadcasted(%12, %14)::Broadcasted{}, try defining it using ...
Cc @dfdx in case he'd like to join this hallway discussion.
The performance is really only due to Zygote unrolling chains, which isn't per se necessary. It also seems foldl wouldn't work generically, especially when composing layers, chains of chains, etc.
Latest version of this PR replaces foldl with a @generated applychain. Applied to the tiny examples above (#1809 (comment)), this is faster, and it has shorter compile times (shown elsewhere).
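For reference, a minimal sketch of what a @generated applychain can look like, assuming the layers are stored in a Tuple (illustrative only, not necessarily the exact code merged here):

@generated function applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  # Build x1 = layers[1](x); x2 = layers[2](x1); ... as straight-line code.
  symbols = vcat(:x, [gensym() for _ in 1:N])
  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
  Expr(:block, calls...)
end

(c::Chain)(x) = applychain(c.layers, x)

Times on the same tiny networks as before: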
julia> using Flux
julia> m1 = Chain(x->vcat(x,x), Dense(2,2,relu));
julia> m2 = Chain(Dense(20,30,relu), Dense(30,10), softmax);
julia> @btime gradient(sum∘$m1, 3.14)
min 14.917 μs, mean 16.347 μs (246 allocations, 9.42 KiB)
(0.310660183429718,)
julia> @btime gradient(sum∘$m2, rand(20,20));
min 41.417 μs, mean 50.995 μs (353 allocations, 68.27 KiB)
julia> m1 = Chain([x->vcat(x,x), Dense(2,2,relu)]);
julia> m2 = Chain([Dense(20,30,relu), Dense(30,10), softmax]);
julia> @btime gradient(sum∘$m1, 3.14)
min 34.333 μs, mean 35.980 μs (459 allocations, 15.06 KiB)
(0.0,)
julia> @btime gradient(sum∘$m2, rand(20,20));
min 73.500 μs, mean 85.917 μs (681 allocations, 76.66 KiB)
The last times here show a chain with layers stored as a vector. This is not type-stable, but helps with compile times. On very small networks like this, it is significantly slower, so it probably cannot be the default.
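To make that trade-off concrete, a hypothetical side-by-side of the two constructors (layer sizes are just for illustration, not from this PR's tests):

# Tuple-backed chain: the layer types are part of Chain's type, so evaluation is fully specialised.
m_tuple = Chain(Dense(20,30,relu), Dense(30,10), softmax)

# Vector-backed chain: the element type is abstract, so there is much less to compile,
# at the cost of dynamic dispatch on every layer call.
m_vector = Chain([Dense(20,30,relu), Dense(30,10), softmax])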
FYI: latest master of Yota now correctly handles foldl. I haven't tested it with the latest version of this PR though, since I'm currently busy speeding up the underlying tracer, so I'm going to wait till the PR is done and catch up later.
We should simplify this, because there are now different rules for how composed chains would work.
Latest version of this replaces foldl with the @generated applychain, as above. Times on a larger model (loss, model, and lr_images come from a larger benchmark, not shown here):
julia> @time loss(lr_images); # first Flux.Chain call, cold start, no Zygote
4.490344 seconds (30.96 M allocations: 1.634 GiB, 7.28% gc time, 99.81% compilation time) # Flux v0.12.8
3.608894 seconds (34.26 M allocations: 1.801 GiB, 8.63% gc time, 99.84% compilation time) # @generated
3.699612 seconds (34.25 M allocations: 1.800 GiB, 8.80% gc time, 99.82% compilation time) # Vector chain
julia> @time loss(lr_images);
0.005024 seconds (2.46 k allocations: 242.812 KiB)
julia> @btime loss($lr_images);
min 7.202 ms, mean 7.648 ms (2443 allocations, 256.98 KiB) # Flux v0.12.8
min 5.533 ms, mean 5.696 ms (2461 allocations, 242.78 KiB) # @generated
min 5.639 ms, mean 5.881 ms (2368 allocations, 233.55 KiB) # Vector chain
julia> loss_grad(x, ps) = gradient(() -> loss(x), ps);
julia> ps = Flux.params(model);
julia> @time loss_grad(lr_images, ps); # first Zygote.gradient, cold start
62.895168 seconds (67.77 M allocations: 3.523 GiB, 0.99% gc time, 99.92% compilation time) # Flux v0.12.8
29.646482 seconds (64.41 M allocations: 3.373 GiB, 2.02% gc time, 99.88% compilation time) # @generated
13.677340 seconds (63.15 M allocations: 3.308 GiB, 3.99% gc time, 99.74% compilation time) # Vector chain
julia> @time loss_grad(lr_images, ps); # second
0.028015 seconds (12.95 k allocations: 2.275 MiB)
julia> @btime $loss_grad($lr_images, $ps);
min 37.933 ms, mean 38.574 ms (13052 allocations, 2.93 MiB) # Flux v0.12.8
min 29.283 ms, mean 29.730 ms (12941 allocations, 2.28 MiB) # @generated
min 30.516 ms, mean 30.957 ms (13436 allocations, 1.62 MiB) # Vector chain
julia> length(model.layers[1].layers.layers) # count of resblock(channels) layers
15

Also included in this comparison is a Chain backed by a Vector. This is type-unstable, and not as fast for very small networks. But for larger ones, it is no slower, and improves compile time by another factor of 2. It's included in this PR for now, but obviously needs tests if it is to stay.
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
BTW, in addition to a hand-written foldl, the function Flux.activations is just accumulate(|>, m1.layers; init=x1). Since we don't support Julia < 1.5, we could just replace it.
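A small usage sketch of that equivalence, with a made-up three-layer chain (untested against this PR, just restating the point):

julia> m1 = Chain(Dense(2,3,relu), Dense(3,2), softmax);

julia> x1 = rand(Float32, 2);

julia> accumulate(|>, m1.layers; init = x1);  # tuple of every intermediate output, like Flux.activations(m1, x1)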
I thought accumulate would face the same issue as foldl, namely that the rrule doesn't consider init? This PR need not be concerned with activations either way; we can kick that can down the road until rrules are tweaked or someone complains about performance.
Indeed, you could likewise do accumulate(|>, (x, m.layers...)). But happy to leave it alone for now.
For SciML we might have regressions in higher-order chains or regressions in smaller chains.
@ChrisRackauckas would you be able to offer a thumbs up/down on whether this is something SciML could tolerate? Also @pxl-th, I don't remember if you've tested this, but if not it may be interesting for your benchmark in FluxML/Zygote.jl#1119.
This is fine. We use […]
Codecov Report
@@            Coverage Diff             @@
##           master    #1809      +/-   ##
==========================================
+ Coverage   74.50%   74.57%   +0.07%
==========================================
  Files          28       28
  Lines        1706     1715       +9
==========================================
+ Hits         1271     1279       +8
- Misses        435      436       +1

Continue to review full report at Codecov.
Good to go?
This uses Base's foldl to feed data through a Chain, instead of a hand-written version. The reason to do so is that this greatly reduced time-to-first-gradient in FluxML/Zygote.jl#1126. It's possible that Flux ought to bound ChainRules with this. Or at least I should check whether Zygote's bound is high enough to guarantee it'll work. But first, see if CI likes this.

Latest version does not call either foldl or Base.afoldl, but instead replaces one hand-written version with another, as this seems to perform better.
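For context, the hand-written unrolled version being replaced was, roughly, this recursive applychain (paraphrased from earlier Flux sources, not copied from this diff):

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(Base.tail(fs), first(fs)(x))  # peel off one layer per recursive call
(c::Chain)(x) = applychain(c.layers, x)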