Significant compile time latency in Flux with default optimization #1126
With Julia master I get this:
If I replace the long tuple with a vector, very crudely:

```julia
struct VChain
    layers
    VChain(xs...) = new(collect(xs))
end

Flux.@forward VChain.layers Base.getindex, Base.length, Base.first, Base.last,
    Base.iterate, Base.lastindex, Base.keys

Flux.functor(::Type{<:VChain}, c) = c.layers, ls -> VChain(ls...)

(c::VChain)(x) = applychain(c.layers, x)
applychain(fs, x) = isempty(fs) ? x : applychain(fs[2:end], first(fs)(x))
```

... then I get better times. Which is perhaps a similar story to the
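For a sense of how to exercise such a chain, here is a minimal sketch (my example, with hypothetical layer sizes; the thread's real model is a 15-resblock SRGAN), assuming the `VChain` definition above is loaded:

```julia
using Flux, Zygote

model = VChain(Dense(10, 32, relu), Dense(32, 32, relu), Dense(32, 1))
x = rand(Float32, 10, 16)

# `layers` is an untyped Vector, so each call is type-unstable, but the
# compiler no longer has to specialize on one huge tuple-of-layers type.
@time Zygote.gradient(m -> sum(m(x)), model);  # first call: dominated by compilation
@time Zygote.gradient(m -> sum(m(x)), model);  # second call: runtime only
```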
Interesting that it allocates less as well. I wonder if switching from
Oh right:

```julia
using Flux

struct FChain{T}
    layers::T
    FChain(xs...) = new{typeof(xs)}(xs)
    FChain(xs::AbstractVector) = new{typeof(xs)}(xs) # allows FChain(Any[Layer, Layer, ...]), type-unstable
end

Flux.@forward FChain.layers Base.getindex, Base.length, Base.first, Base.last,
    Base.iterate, Base.lastindex, Base.keys

Flux.functor(::Type{<:FChain}, c) = c.layers, ls -> FChain(ls...)
Flux.functor(::Type{<:FChain{<:AbstractVector}}, c) = c.layers, ls -> FChain(ls) # maybe?

# (c::FChain)(x) = foldl((y,f) -> f(y), c.layers; init=x) # NO, this forgets the gradient for x
(c::FChain)(x) = foldl((y,f) -> f(y), (x, c.layers...))
(c::FChain{<:AbstractVector})(x) = foldl((y,f) -> f(y), vcat([x], c.layers))
```
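The `init=x` pitfall in the commented-out line can be seen in isolation; a minimal sketch (my example, not from the thread), relying on the convention discussed below that keyword arguments receive no gradient:

```julia
using Zygote

f_kwarg(x) = foldl((y, f) -> f(y), (sqrt, sqrt); init = x) # x enters as a kwarg
f_tuple(x) = foldl((y, f) -> f(y), (x, sqrt, sqrt))        # x enters as an element

gradient(f_kwarg, 16.0) # likely (nothing,): the gradient for x is dropped
gradient(f_tuple, 16.0) # (0.03125,): d/dx of x^(1/4) at x = 16
```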
That's on a cold start. By accident, running this example after some unrelated checks (
So perhaps 10s of this is generic startup (and perhaps that could be helped by precompilation?) and not the effects which scale. Comparable warm start for the original Chain:
There are precompilation statements for explicit params, but not for the version of
Most of the code is shared between the two modes, and the precompiled code would be cached anyway, so I don't know how much of a difference it would make.
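For reference, such statements look roughly like this sketch; the concrete signature is my guess, and real ones would typically be generated, e.g. with SnoopCompile:

```julia
using Zygote

# Hypothetical precompile directive for the implicit-parameters path.
precompile(Tuple{typeof(Zygote.gradient), Function, Zygote.Params})
```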
Thank you all for the insights. I am impressed by how easy it is to implement an alternative Chain structure. The proposed foldl approach also works well in my original SRGAN implementation.
A bigger change would be to allow, or even automatically switch to, a deliberately type-unstable approach. This could be nicer (and probably faster) than my crude
@Alexander-Barth if you have them, I'd love to see how the times look now on 1.6/1.7. If we can think of some easy tests to ensure
True, allowing
@ToucheSir Of course, here are the times for Julia 1.7.0-rc3 (first and 2nd call of pullback)
Julia 1.6.3:
So the timings are very similar. Surprisingly, `VChain` is a bit faster:

Julia 1.7.0-rc3 (first and 2nd call of pullback)
Julia 1.6.3:
Thanks for checking. Interesting that the 2nd run with FChain is 10x slower than with VChain. Do they settle down (or
Indeed with

Is
I think it's both, really -- by convention neither package believes keyword arguments should have gradients. This is perhaps the most convincing example I've seen where the gradient would be desired. Cc @oxinabox for design thoughts? I edited the code above to make it the first element
In theory, I think we can write rules for gradients w.r.t. keyword arguments by writing rules for:

But I have never tested that. I think it might work if the version that takes kwargs without gradients isn't defined, though.
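Untested, as stated, but for concreteness here is a sketch of what such a rule might look like under Julia's pre-1.9 keyword-call lowering. The function `mykw` and its rule are hypothetical, and whether Zygote would actually consult this method is exactly the untested part:

```julia
using ChainRulesCore

mykw(a; scale = 1.0) = scale * a

# Keyword calls lower to a call of a hidden "kwfunc" that receives the
# NamedTuple of kwargs as a positional argument, so a rule written for
# that function can, in principle, return a tangent for the kwargs too.
function ChainRulesCore.rrule(::Core.kwftype(typeof(mykw)), kwargs, ::typeof(mykw), a)
    y = mykw(a; kwargs...)
    function mykw_kw_pullback(ȳ)
        NoTangent(), Tangent{typeof(kwargs)}(; scale = ȳ * a), NoTangent(), ȳ * kwargs.scale
    end
    return y, mykw_kw_pullback
end
```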
Here's the effect on the above example of adding
I believe that most of the speedup is coming from adding it to Zygote; it would be worth separating the effect of each of these. We should also benchmark some functions with more Zygote and less BLAS, I guess, to see if this has bad effects elsewhere:

```julia
julia> # @eval Flux (c::Chain)(x) = Base.afoldl(|>, x, c.layers...)

julia> @time loss(lr_images); # first Flux.Chain call, cold start, no Zygote
4.260597 seconds (30.96 M allocations: 1.634 GiB, 7.06% gc time, 99.79% compilation time) # tagged
4.621733 seconds (31.05 M allocations: 1.638 GiB, 7.00% gc time, 99.78% compilation time) # Base.Experimental IRTools + Zygote

julia> @time loss(lr_images);
0.007167 seconds (2.47 k allocations: 259.031 KiB)

julia> @btime loss($lr_images);
min 7.199 ms, mean 7.768 ms (1882 allocations, 203.14 KiB) # tagged
min 7.373 ms, mean 7.585 ms (2432 allocations, 255.34 KiB) # tagged, later

julia> loss_grad(x, ps) = gradient(() -> loss(x), ps);

julia> ps = Flux.params(model);

julia> @time loss_grad(lr_images, ps); # first Zygote.gradient, cold start
61.904481 seconds (67.77 M allocations: 3.523 GiB, 1.08% gc time, 99.88% compilation time) # tagged
14.570048 seconds (63.99 M allocations: 3.330 GiB, 3.65% gc time, 99.67% compilation time) # Base.Experimental IRTools + Zygote
14.803469 seconds (64.02 M allocations: 3.331 GiB, 3.65% gc time, 99.66% compilation time) # Base.Experimental IRTools + Zygote + ZygoteRules
15.253339 seconds (64.76 M allocations: 3.369 GiB, 3.76% gc time, 99.68% compilation time) # Base.Experimental Zygote

julia> @time loss_grad(lr_images, ps); # second
0.036655 seconds (13.07 k allocations: 2.928 MiB)
0.037986 seconds (13.14 k allocations: 2.983 MiB)

julia> @btime $loss_grad($lr_images, $ps);
min 38.108 ms, mean 38.486 ms (13060 allocations, 2.93 MiB) # tagged
min 37.827 ms, mean 38.431 ms (13024 allocations, 2.92 MiB) # tagged, later
min 39.340 ms, mean 40.376 ms (13090 allocations, 2.98 MiB) # Base.Experimental IRTools + Zygote
min 40.165 ms, mean 41.283 ms (13122 allocations, 2.98 MiB) # Base.Experimental IRTools + Zygote + ZygoteRules
min 39.469 ms, mean 40.169 ms (13124 allocations, 2.98 MiB) # Base.Experimental Zygote

julia> length(model.layers[1].layers.layers) # count of resblock(channels) layers
15

##############################
# With Base.afoldl for Chain:

julia> @time loss(lr_images); # first Flux.Chain call, cold start, no Zygote
3.790401 seconds (30.96 M allocations: 1.634 GiB, 8.04% gc time, 99.83% compilation time) # tagged
3.732328 seconds (30.98 M allocations: 1.635 GiB, 7.78% gc time, 99.83% compilation time) # Base.Experimental Zygote

julia> @time loss(lr_images);
0.005440 seconds (2.47 k allocations: 245.625 KiB)

julia> @btime loss($lr_images);
min 5.854 ms, mean 6.082 ms (2453 allocations, 243.89 KiB) # tagged
min 6.009 ms, mean 6.153 ms (2462 allocations, 245.47 KiB) # Base.Experimental Zygote

julia> @time loss_grad(lr_images, ps); # first Zygote.gradient, cold start
21.698580 seconds (78.74 M allocations: 4.044 GiB, 3.35% gc time, 99.76% compilation time) # tagged
14.395047 seconds (75.68 M allocations: 3.890 GiB, 4.68% gc time, 99.68% compilation time) # Base.Experimental IRTools + Zygote
13.474056 seconds (75.70 M allocations: 3.891 GiB, 4.54% gc time, 99.66% compilation time) # Base.Experimental IRTools + Zygote + ZygoteRules
14.054168 seconds (76.63 M allocations: 3.934 GiB, 4.57% gc time, 99.68% compilation time) # Base.Experimental Zygote

julia> @btime $loss_grad($lr_images, $ps);
min 41.539 ms, mean 43.545 ms (17348 allocations, 1.80 MiB) # tagged
min 40.177 ms, mean 40.682 ms (17390 allocations, 1.81 MiB) # tagged, later
min 40.027 ms, mean 40.519 ms (17439 allocations, 1.82 MiB) # Base.Experimental IRTools + Zygote
min 41.294 ms, mean 46.930 ms (16991 allocations, 1.78 MiB) # Base.Experimental IRTools + Zygote + ZygoteRules
min 40.574 ms, mean 41.111 ms (17421 allocations, 1.82 MiB) # Base.Experimental Zygote

########################
# With `foldl` instead:

julia> # @eval Flux (c::Chain)(x) = Base.afoldl(|>, x, c.layers...)
       @eval Flux (c::Chain)(x) = foldl(|>, (x, c.layers...))

julia> @time loss(lr_images); # first Flux.Chain call, cold start, no Zygote
4.412313 seconds (32.94 M allocations: 1.749 GiB, 7.79% gc time, 99.76% compilation time) # tagged
4.478969 seconds (32.97 M allocations: 1.750 GiB, 6.88% gc time, 99.81% compilation time) # Base.Experimental Zygote

julia> @time loss(lr_images);
0.008333 seconds (2.50 k allocations: 267.562 KiB)

julia> @btime loss($lr_images);
min 8.338 ms, mean 8.797 ms (2452 allocations, 262.05 KiB)
min 8.208 ms, mean 8.842 ms (2480 allocations, 265.52 KiB)

julia> loss_grad(x, ps) = gradient(() -> loss(x), ps);

julia> ps = Flux.params(model);

julia> @time loss_grad(lr_images, ps); # first Zygote.gradient, cold start
45.873403 seconds (72.27 M allocations: 3.775 GiB, 1.52% gc time, 99.86% compilation time) # tagged
30.731573 seconds (69.52 M allocations: 3.630 GiB, 2.07% gc time, 99.82% compilation time) # Base.Experimental Zygote

julia> @time loss_grad(lr_images, ps); # second
0.048166 seconds (14.88 k allocations: 4.123 MiB)

julia> @btime $loss_grad($lr_images, $ps);
min 49.859 ms, mean 51.324 ms (14801 allocations, 4.12 MiB) # tagged
min 46.516 ms, mean 48.572 ms (14894 allocations, 4.09 MiB) # Base.Experimental Zygote

julia> length(model.layers[1].layers.layers) # count of resblock(channels) layers
15
```
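The comments above don't name the change the "Base.Experimental" rows refer to; my assumption (not confirmed by this thread) is that it means compiling whole packages at a lower optimization level via the `Base.Experimental.@optlevel` macro, roughly:

```julia
# Hypothetical illustration: opting a compile-heavy package down to -O1
# module-wide, so all of its own methods are compiled with fewer passes.
module SomeAutodiffPackage

Base.Experimental.@optlevel 1 # macro available since Julia 1.5

# ... package code as before ...

end
```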
There are most certainly better test cases, but off the top of my head RNNs (especially pointwise-heavy cells like
Indeed, maybe from FluxML/Flux.jl#1761. (Which still isn't merged?)
Pre-#1761 might be better, as AD is more involved there.
A much simpler test of some complicated broadcasting shows 3-10x slowdowns: #1147 (comment). So I doubt

Another idea is to replace
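For anyone wanting to reproduce that kind of comparison, here is a sketch of a broadcast-heavy micro-benchmark (my construction, not the code from #1147) to run once under the default `-O2` and once under `-O1`:

```julia
using Zygote, BenchmarkTools

# Fused, pointwise-heavy broadcast: almost no BLAS, lots of AD machinery.
f(x, y) = sum(tanh.(x .* y .+ x ./ y))

x = rand(Float32, 100, 100)
y = rand(Float32, 100, 100)

@btime f($x, $y)
@btime gradient($f, $x, $y) # compare this line across optimization levels
```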
Closing in favor of #1119; please reopen if the two issues are not the same.
I am trying to optimize a super-resolution GAN, and I see very significant compile time.
This issue is discussed here:
https://discourse.julialang.org/t/significant-compile-time-latency-in-flux-with-a-gan/68518
@ToucheSir was able to reduce this issue to the code below (thanks a lot!).
I am using Julia 1.6.3 on Linux and these packages:
It appears that this issue is sensitive to the optimization level and to the julia version used.
with julia 1.6.3 (default optimization):
Note the timing of the first call to `loss_grad(lr_images, ps)`. In Julia 1.5.3 I get the following timing:
with "julia 1.6.3 -O1":
This issue might be related to #1119.