Add norm functions #452
Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR452/
75b42c3 to 3a52deb
I have some small comments below from a first pass.
But I also have a gist where I was fiddling with some approaches, and can compare them:
https://gist.github.com/mcabbott/6154bb78b735e8f0a9348767a7d59c86
The tl;dr is that trying to fuse the gradient got me 1.45x faster & 29% the memory usage of Flux's approach. This PR gets 88%. Is it worth trying to bolt something like that on?
The CUDA routine for batchnorm seems to do even better, I don't know how. But I may be measuring it wrong... I did not really think about the batchnorm case. That seems to be a lot of the complexity here.
src/normalization.jl
Outdated
if ChainRulesCore.is_inplaceable_destination(running_mean)
    stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
else
    stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
end
Might it be simpler to just demand that these be mutable?
The idea was this gets us SArray support for relatively cheap. Granted I didn't actually test with StaticArrays, so coverage might not be happy.
I get this in abstract, but I think it's quite a bit of complexity, and there would have to be a real gain over just storing an MArray (and ideally someone who wants this).
I'd have to think what happens if an array has wider type than expected... it will be down-converted?
Asking for parameter types to support similar so that Flux can always construct a RunningStats doesn't seem unreasonable. I also haven't thought about eltype mismatches.
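To make the two branches discussed in this thread concrete, here is a minimal sketch of the moving-average update in both styles. The helper names `update_running!` and `update_running` are hypothetical and not part of NNlib, and the sketch assumes `res_mtm` in the snippet above means `1 - momentum`. It also illustrates the eltype question: the in-place form converts down to the destination's eltype, while the out-of-place form promotes.

```julia
# Hypothetical helpers, not part of NNlib. Assumes res_mtm == 1 - momentum.

# In-place: needs a mutable destination such as Array or MArray.
function update_running!(running::AbstractVector, batch::AbstractVector, momentum::Real)
    running .= (1 - momentum) .* running .+ momentum .* batch
    return running
end

# Out-of-place: builds a new array, so it also works for immutable inputs.
function update_running(running::AbstractVector, batch::AbstractVector, momentum::Real)
    return (1 - momentum) .* running .+ momentum .* batch
end

running = Float32[0, 0]
batch = Float64[1, 2]                 # deliberately wider eltype than `running`

new_running = update_running(running, batch, 0.1)   # promoted to Float64
update_running!(running, batch, 0.1)                # stays Float32
```

With plain arrays the in-place version silently keeps the stored precision, which may or may not be what a caller with wider batch statistics expects.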
src/normalization.jl
Outdated
    error("both scale and bias must be provided or left as nothing")
end
scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
I think it will pay to separate out the inv & sqrt on N things, instead of calling them N^2 times. This was the major forward speedup I found.
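A rough sketch of what that separation could look like, for comparison with the fused broadcast in the diff. The function names and the `features x channels x batch` layout are assumptions for illustration only, not code from this PR.

```julia
using Statistics

# Fused form: sqrt and the division are evaluated once per element of x.
normalize_naive(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

# Hoisted form: inv and sqrt run once per statistic (one per channel here),
# leaving a subtraction and a multiplication per element.
function normalize_hoisted(x, μ, σ², ϵ)
    invσ = inv.(sqrt.(σ² .+ ϵ))
    return (x .- μ) .* invσ
end

x = randn(Float32, 8, 3, 4)        # assumed layout: features x channels x batch
dims = (1, 3)                      # reduce over everything except channels
μ = mean(x; dims)
σ² = var(x; dims, mean = μ, corrected = false)
ϵ = 1f-5

y_naive = normalize_naive(x, μ, σ², ϵ)
y_hoisted = normalize_hoisted(x, μ, σ², ϵ)
```

The hoisted form allocates one extra small array the size of the statistics, which is usually negligible next to `x`.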
My thought was that this PR (and any future functions which advertise |
Ah I had not seen that. It looks optimised but I don't immediately see what it's actually doing! On the CPU it seems to give correct answers but is not especially efficient. However, I see some multi-arg The counter to trying to bolt on speed afterwards is that it may want different organisation. I haven't absorbed how this PR structures things, but for me, absorbing the |
It's likely optimized routines will want to replace everything and just fulfill the high-level |
It's focusing on GPU arrays. The idea is to use GPUArrays' broadcast/mapreduce kernels for simple kernel fusion and to reduce GPU allocations. So, for example, in the layer norm forward function:
function layer_norm(epsilon, alpha, beta, x)
    # [...]
    sum_sum2 = mapreduce(_x_x2, .+, x; dims=1, init = (zero(T), zero(T)))
    return fma.(α, _normalize.(convert(T, 1//N), ϵ, x, sum_sum2), β)
end
This takes only two GPU kernels to compute the layer norm, and only 1 intermediate array is created ( |
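For readers who want to try the sum / sum-of-squares idea without a GPU, here is a small CPU-only sketch. It is not the quoted implementation: it loops over columns instead of using a `dims` reduction, and the names `sum_and_sumsq` and `layer_norm_sketch` are made up for this example.

```julia
using Statistics

# Map each element to (x, x^2) and add the pairs, so the sum and the sum of
# squares both come out of a single pass over the data.
sum_and_sumsq(v) = mapreduce(x -> (x, x * x), (a, b) -> a .+ b, v;
                             init = (zero(eltype(v)), zero(eltype(v))))

function layer_norm_sketch(x::AbstractMatrix, ϵ)
    N = size(x, 1)
    y = similar(x)
    for (j, col) in enumerate(eachcol(x))
        s, s2 = sum_and_sumsq(col)
        μ = s / N
        σ² = s2 / N - μ^2              # uncorrected variance, E[x^2] - E[x]^2
        y[:, j] .= (col .- μ) .* inv(sqrt(σ² + ϵ))
    end
    return y
end

x = randn(Float64, 16, 5)
y = layer_norm_sketch(x, 1e-5)
```

The `E[x^2] - E[x]^2` form is convenient for fusion but can lose precision when the mean is large relative to the spread, which is relevant to the accuracy discussion further down.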
Cool. I've updated my gist above to time things... this is indeed fast. On the GPU, I discover that On the CPU it does some work N^2 times, but the cost of that can be largely solved by What all this means for this PR, I don't know. I suppose that I think the basic implementation should (1) try to be efficient without doing anything exotic, Array and CuArray, (2) must not error on SArray, FillArray, etc, but little point optimising that, and (3) try to line up with cuDNN etc. routines to make such overloads easy, and perhaps try to make easy to hook on other overloads like using LoopVectorization. Whether this |
Thanks for timing it. Were the reported numbers for I don't know much about fastmath, but I vaguely remember it is considered evil? |
Not 100% sure but I think it made no difference on GPU. Weird beasts but computation often seems free compared to memory anyway. Not sure I have an opinion on how evil it is. Here it seems pretty safe; it may fail to propagate NaN via the variance, but you'll get it again from |
If one really wants to go wild with optimizations, it's possible to fuse the computation of
|
The issue with that approach is that we would need to go down to kernel programming with either CUDA.jl or KernelAbstractions.jl. The input size would also affect performance, so in PyTorch they implement multiple kernels and dispatch on the input type and size. I guess part of the success of Triton also comes from their compiler?
It would be interesting to know, because the kernels are so much shorter too. Wonder what performance we'd get with Julia versions. A Triton-like library would be a nice force multiplier in this ecosystem.
I think this
On CPU without |
That doesn't look good. Maybe it's worth switching to Welford algo for that.
It sounds like that's the done thing. Can it use |
It should be doable with |
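For reference, a minimal scalar sketch of Welford's one-pass algorithm mentioned above. This is a plain sequential loop, not a GPU or `mapreduce` formulation, and `welford_mean_var` is a hypothetical name; the return shape mirrors the NamedTuple idea floated elsewhere in this thread.

```julia
using Statistics

# Welford's update keeps a running mean and a running sum of squared
# deviations (M2) instead of raw sums of x and x^2.
function welford_mean_var(xs; corrected::Bool = false)
    n = 0
    μ = zero(float(eltype(xs)))
    M2 = zero(μ)
    for x in xs
        n += 1
        δ = x - μ
        μ += δ / n
        M2 += δ * (x - μ)          # second factor uses the updated mean
    end
    return (; mean = μ, var = M2 / (n - corrected))
end

# Large mean, small spread: the regime where E[x^2] - E[x]^2 tends to lose
# precision in Float32.
xs = 1f4 .+ randn(Float32, 10_000)
stats = welford_mean_var(xs)
naive = sum(abs2, xs) / length(xs) - (sum(xs) / length(xs))^2
```

Comparing `stats.var` and `naive` against `var(xs; corrected = false)` is a quick way to see how the two approaches behave on a given input.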
function norm_stats(x, dims)
    μ = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    return μ, σ²
end
Maybe this function need not be closely tied to norm, it's just mean-and-var for any purpose you like. Saying what it does might be clearer than saying what it's for.
So I'd propose, for name & signature, following Statistics closely. And returning a NamedTuple?
function mean_var(x::AbstractArray; dims=:, corrected::Bool=true)
    μ = mean(x; dims)
    σ2 = var(x; dims, mean=μ, corrected)
    (; mean=μ, var=σ2)
end
I almost wonder if this function should live upstream somewhere... like Statistics?
My current attempt at a one-pass GPU version, which can happily overload this signature, is here:
This function lives here, is called norm_stats instead of mean_var, and doesn't take keyword args, all for one and the same reason: we need a Zygote-friendly function that selectively looks up running stats or calculates them. https://github.com/FluxML/NNlib.jl/blob/bc/norm-functions/src/normalization.jl#L108 doesn't really work with kwargs, so ; dims=... is not an option. Thus it'd be a little weird to advertise this function as a general-purpose fused mean + var, because it doesn't follow the same interface.
What's the status here? Can we delay optimization/generalization to future PRs and focus on approximately porting existing functionality if that's what's needed to move on?
Kind ping on the status of this PR as it will unblock support for AMDGPU backend (and we can specialize batchnorm using MIOpen for which I plan to open a PR).
These roughly correspond to Flux's `*Norm` layers.
8c9b6ea to 2ce9c42
These roughly correspond to Flux's *Norm layers. Constitutes the bulk of the work in #19. Dropout is a different beast and deserving of a separate discussion/its own PR. In the meantime, this should give NNlib{CUDA,ROC} a common interface to implement and allow Flux to start divesting its own norm function helpers.
Design Notes
The affine parameters are included as part of the function signature here because backends like cuDNN can fuse them into the main layer computation. This differs from Flux's use of Scale in LayerNorm, so how do we best reconcile the two? Activation functions are not included in this API for a similar reason (backends don't handle them).
Another thing not addressed in this PR is alternative memory layouts/dimension orderings such as channels-first. Hopefully once we figure out a way to represent inputs with different layouts, we can add internal extension points that dispatch on those. For now, everything is assumed to be spatial dims... x channels [x batch].
Relatedly, juggling dims is an absolute pain and was both the most painful and the most time-consuming part of developing this PR. I wish there was a way to use named dims while not requiring users or backend implementors to do the same. Something for future work.
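As an illustration of the dims bookkeeping described above, here is a small sketch for the assumed `spatial dims... x channels x batch` layout. The helpers `channel_dim`, `reduce_dims`, and `affine_shape` are hypothetical names invented for this example and are not the PR's internals.

```julia
using Statistics

# Assumed layout: spatial dims... x channels x batch, so the channel
# dimension is second to last.
channel_dim(x::AbstractArray{<:Any,N}) where {N} = N - 1

# Batch-norm style reduction: every dimension except the channel one.
function reduce_dims(x::AbstractArray{<:Any,N}) where {N}
    c = channel_dim(x)
    return ntuple(i -> i < c ? i : i + 1, N - 1)
end

# Shape that lets a length-C scale or bias vector broadcast against x.
function affine_shape(x::AbstractArray{<:Any,N}) where {N}
    c = channel_dim(x)
    return ntuple(i -> i == c ? size(x, c) : 1, N)
end

x = randn(Float32, 5, 5, 3, 2)          # width x height x channels x batch
dims = reduce_dims(x)
μ = mean(x; dims)
scale = reshape(ones(Float32, 3), affine_shape(x))
y = (x .- μ) .* scale
```

A channels-first layout would only need a different `channel_dim`, which hints at where a layout-aware extension point could sit.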