Move dropout to NNlib #2150

@@ -123,7 +123,7 @@ LayerNorm
InstanceNorm
GroupNorm
Flux.normalise
Flux.dropout
NNlib.dropout
```

### Test vs. Train

@@ -1,111 +1,85 @@
# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)

_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
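
These helpers are small enough to sanity-check by hand. A quick sketch of what they compute, assuming the three definitions above are in scope (not part of the diff):

```julia
julia> _dropout_shape(ones(3, 4), :)      # Colon: the mask covers every element
(3, 4)

julia> _dropout_shape(ones(3, 4), 2)      # dims=2: other dimensions collapse to 1, so the mask broadcasts
(1, 4)

julia> _dropout_kernel(0.7f0, 0.5, 0.5)   # draw above p: kept, scaled by 1/q = 1/(1 - p)
2.0f0

julia> _dropout_kernel(0.2f0, 0.5, 0.5)   # draw below p: dropped, in the input's eltype
0.0f0
```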
""" | ||
dropout([rng = rng_from_array(x)], x, p; dims=:, active=true) | ||
|
||
The dropout function. If `active` is `true`, | ||
for each input, either sets that input to `0` (with probability | ||
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions, | ||
e.g. `dims=1` applies dropout along columns and `dims=2` along rows. | ||
If `active` is `false`, it just returns the input `x`. | ||
|
||
Specify `rng` for custom RNGs instead of the default RNG. | ||
Note that custom RNGs are only supported on the CPU. | ||
|
||
Warning: when using this function, you have to manually manage the activation | ||
state. Usually in fact, dropout is used while training | ||
but is deactivated in the inference phase. This can be | ||
automatically managed using the [`Dropout`](@ref) layer instead of the | ||
`dropout` function. | ||
|
||
The [`Dropout`](@ref) layer is what you should use in most scenarios. | ||
""" | ||
function dropout(rng, x, p; dims=:, active::Bool=true) | ||
active || return x | ||
y = dropout_mask(rng, x, p, dims=dims) | ||
return x .* y | ||
end | ||
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...) | ||
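
For orientation, a minimal usage sketch of the standalone function, assuming the definitions above (the zero pattern depends on the RNG, so the comments only describe the shape of the result):

```julia
using Random

x = ones(Float32, 3, 4)

y  = dropout(Random.default_rng(), x, 0.5)           # each entry is either 0.0f0 or 2.0f0
y1 = dropout(Random.default_rng(), x, 0.5; dims=1)   # the mask is 3×1: each row is kept or dropped whole

dropout(x, 0.5; active=false) === x                   # inactive: returns the input unchanged
```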

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))

dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
  realfptype = float(real(eltype(x)))
  y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
  y .= _dropout_kernel.(y, p, 1 - p)
  return y
end

# TODO move this to NNlib
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
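
As a sanity check on the two pieces above, a hedged sketch of what `dropout_mask` returns and why marking it non-differentiable is enough (values are illustrative):

```julia
using Random

mask = dropout_mask(Random.default_rng(), ones(Float32, 2, 3), 0.25)
# 2×3 Matrix{Float32}: each entry is either 0.0f0 (dropped) or 1/(1 - 0.25) ≈ 1.3333334f0 (kept)

# Since the mask is @non_differentiable, AD only sees the broadcast `x .* mask` with a
# constant mask, so the pullback for `x` is simply multiplication by that same mask.
```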

"""
    Dropout(p; dims=:, rng = default_rng_value())

Dropout layer.
Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.

While training, for each input, this layer either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
training.
While training, it sets each input to `0` (with probability `p`)
or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
details.
By default the mode will switch automatically, but it can also
be controlled manually via [`Flux.testmode!`](@ref).

Specify `rng` to use a custom RNG instead of the default.
Custom RNGs are only supported on the CPU.
By default every input is treated independently. The `dims` keyword
instead makes the random choice only along that dimension, sharing it across the others.

For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).

Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
Keyword `rng` lets you specify a custom random number generator.
(Only supported on the CPU.)

# Examples
```jldoctest
julia> m = Chain(Dense(1 => 1), Dropout(1));
```julia
julia> m = Chain(Dense(ones(3,2)), Dropout(0.4))
Chain(
  Dense(2 => 3),                        # 9 parameters
  Dropout(0.4),
)

julia> Flux.trainmode!(m);
julia> m(ones(2, 7))  # test mode, no effect
3×7 Matrix{Float64}:
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0

julia> y = m([1]);
julia> Flux.trainmode!(m);  # would happen within gradient

julia> y == [0]
true
julia> m(ones(2, 7))
3×7 Matrix{Float64}:
 0.0      0.0      3.33333  0.0      0.0      0.0  0.0
 3.33333  0.0      3.33333  0.0      3.33333  0.0  3.33333
 3.33333  3.33333  0.0      3.33333  0.0      0.0  3.33333

julia> m = Chain(Dense(1000 => 1000), Dropout(0.5));
julia> y = m(ones(2, 10_000));

julia> Flux.trainmode!(m);
julia> using Statistics

julia> y = m(ones(1000));
julia> mean(y)  # is about 2.0, as for test mode
1.9892222222222182

julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
true
julia> mean(iszero, y)  # is about 0.4
0.40323333333333333
```
"""
mutable struct Dropout{F,D,R<:AbstractRNG}
mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
  p::F
  dims::D
  active::Union{Bool, Nothing}
  rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

Review comments:
- Not sure if this is intentional but the error checking seems to only apply to the keyword based constructor.
- I think that's the only "public" one. I have no idea why we have this 3-arg constructor.
- Yeah, that's my recollection.
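
A hypothetical snippet (not part of the diff) illustrating the asymmetry raised above, assuming the constructors exactly as written here:

```julia
Dropout(1.5)              # keyword constructor: throws ArgumentError, since p is range-checked
Dropout(1.5, :, nothing)  # 3-arg constructor: no check, quietly builds a Dropout with p = 1.5
```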

function Dropout(p; dims=:, rng = default_rng_value())
  @assert 0 ≤ p ≤ 1
function Dropout(p::Real; dims=:, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))

  if p isa Integer  # Dropout(0)
    return p == 0 ? identity : zero
  end

  Dropout(p, dims, nothing, rng)
end
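
A short, hedged sketch of what the integer shortcut in this constructor gives you (illustrative only):

```julia
Dropout(0)     # an Integer: returns the function `identity`, no layer is built at all
Dropout(1)     # an Integer: returns the function `zero`, i.e. everything is dropped
Dropout(0.0)   # not an Integer: still builds a Dropout layer, with p = 0.0
Dropout(0.4)   # the usual case: a Dropout layer with p = 0.4, dims = :, active = nothing
```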

@functor Dropout
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
  _isactive(a, x) || return x
  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
  if _isactive(a, x) && a.p != 0
    dropout(a.rng, x, a.p; dims=a.dims)
  else
    x
  end
end

testmode!(m::Dropout, mode=true) =

Review comments:
- I made this to preserve the `active` keyword. Not entirely sure whether use of that was supported outside Flux. The exported function is this one, not the NNlib one, and I think it lacks a docstring at present.
- I added it to FluxML/NNlib.jl#452 because backends need it, but to my knowledge none do for dropout, so this seems fine.
- You think it might be OK to just go without the deprecation?
- Not sure. A quick JuliaHub search didn't turn up anything, but I've enabled downstream tests in case we want to check.
- Now 0e396a6 removes this entirely. At least to see if anything changes.