Commit 2bae421

Add dropout (#454)
* add dropout
* tidy up
* nan & complex fixes
* test dropout! and allow rng
* fixup
* fix 1.6
1 parent ccf1732 · commit 2bae421

File tree

5 files changed: +244 −1 lines changed


Project.toml

+2 −1
@@ -1,12 +1,13 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.13"
+version = "0.8.14"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

src/NNlib.jl

+4
@@ -6,6 +6,7 @@ using ChainRulesCore
 import ChainRulesCore: rrule
 using Base.Broadcast: broadcasted
 using Base.Threads
+using Random
 using Statistics
 using Statistics: mean
 using LinearAlgebra
@@ -40,6 +41,9 @@ for f in ACTIVATIONS
 end
 export sigmoid, hardsigmoid, logsigmoid, thresholdrelu  # Aliases

+include("dropout.jl")
+export dropout, dropout!
+
 include("softmax.jl")
 export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
     logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp

src/dropout.jl

+157
@@ -0,0 +1,157 @@
"""
    dropout([rng], A, p; [dims])

Returns an array in which each element of `A` is either replaced with zero,
with probability `p`, or else multiplied by `1/(1-p)`.

By default every element is treated independently.
With keyword `dims=1`, a choice is made for every value of the 1st index,
i.e. each row of a matrix is either zeroed or not.

Optional first argument is the random number generator used.

# Examples
```
julia> dropout(ones(2, 10), 0.2)
2×10 Matrix{Float64}:
 1.25  1.25  0.0   1.25  1.25  1.25  1.25  1.25  1.25  1.25
 1.25  1.25  1.25  0.0   1.25  1.25  0.0   1.25  1.25  1.25

julia> mean(dropout(ones(10^4, 5), 0.2), dims=1)
1×5 Matrix{Float64}:
 0.998  1.00075  0.99125  0.99575  1.00075

julia> dropout(ones(5, 5), 0.7, dims=1)  # whole row the same
5×5 Matrix{Float64}:
 3.33333  3.33333  3.33333  3.33333  3.33333
 0.0      0.0      0.0      0.0      0.0
 0.0      0.0      0.0      0.0      0.0
 3.33333  3.33333  3.33333  3.33333  3.33333
 0.0      0.0      0.0      0.0      0.0

julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
1×5 Matrix{Float64}:
 1.00571  1.00571  1.00571  1.00571  1.00571
```
"""
dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)

function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
    T = float(eltype(A))
    0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
    if p > 0
        dst = similar(A, T, size(A))
        pT = convert(real(T), p)
        _dropout!(rng, dst, A, pT, dims)
    else
        # Not so sure we want fast paths... this tries but doesn't guarantee type-stability,
        # and the rrule does not have such a fast path.
        convert(AbstractArray{T}, A)
    end
end

"""
    dropout!(B, A, p; dims=:)

This does exactly `B .= dropout(A, p; dims)`,
or rather, it's the implementation behind out-of-place [`dropout`](@ref).
"""
dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_from_array(B), B, A, p; dims)

function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims = :)
    size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
    0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
    if p > 0
        pT = convert(real(eltype(dst)), p)
        _dropout!(rng, dst, src, pT, dims)
    else
        # This fast path isn't free, but there are no concerns about types changing:
        copyto!(dst, src)
    end
end

# This is the easy case, in that we can safely use the output array for random numbers.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)
    T = real(eltype(dst))
    val = convert(T, 1/(1-p))
    rand!(rng, dst)
    ## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast!
    # dst .= (dst .> p) .* val .* src
    _fast_broadcast!(dst, src) do q, x
        ((real(q) > p) * val) * x
    end
    dst
end

# For other dims, we do need to allocate something.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)
    T = real(eltype(dst))
    tmp = similar(dst, T, ntuple(d -> d in dims ? size(src, d) : 1, ndims(src)))
    rand!(rng, tmp)
    val = convert(T, 1/(1-p))
    ## One-pass strategy -- faster on GPU
    dst .= ((tmp .> p) .* val) .* src
    ## Two-pass strategy -- slightly faster on some CPUs?
    # _fast_broadcast!(tmp) do q
    #     (q > p) * val
    # end
    # dst .= tmp .* src
end

# The gradient needs to keep the random choices made, thus store at least a BitArray,
# but the following way turns out to be faster & simpler:
function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
    T = float(real(eltype(A)))
    val = convert(T, 1/(1-p))
    keep = if dims isa Colon
        similar(A, T, size(A))
    else
        similar(A, T, ntuple(d -> d in dims ? size(A, d) : 1, ndims(A)))
    end
    rand!(rng, keep)
    Y = @. ((keep > p) * val) * A
    function dropout_back(Δ)
        dY = unthunk(Δ)
        dA = @. ((keep > p) * val) * dY
        (NoTangent(), NoTangent(), dA, NoTangent())
    end
    return Y, dropout_back
end
# Possibly TODO: another approach to the gradient would be to copy the RNG
# and then re-generate the same mask, instead of storing it. This saves memory
# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402

"""
    _fast_broadcast!(f, x, y, z...)

This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Not intended for general use. Does not check sizes!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
    @simd ivdep for I in eachindex(bc)
        @inbounds x[I] = bc[I]
    end
    return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
    # CUDA does not suffer from this bug
    broadcast!(f, x, x, yz...)
end

"""
    _rng_from_array(x)

Return the random number generator most appropriate for `x`:
`CUDA.default_rng()` for `CuArray`, else `Random.default_rng()`.
"""
_rng_from_array(::AbstractArray) = Random.default_rng()

@non_differentiable _rng_from_array(::Any)

test/dropout.jl

+77
@@ -0,0 +1,77 @@
using NNlib, Test, Statistics, Random, LinearAlgebra
using Zygote, StableRNGs, ChainRulesCore

@testset "dropout" begin
    # Basics
    x1 = randn(Float32, 3, 4)
    @test size(@inferred dropout(x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)
    @test eltype(dropout(x1, 0.1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=(1,2))) == Float32

    rng = Random.default_rng()
    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

    x2 = Diagonal(randn(Float32, 10))  # Just to check it runs on weird matrices.
    if VERSION > v"1.8-"  # on 1.6 this makes a sparse array.
        @test dropout(x2, 0.3) isa Matrix{Float32}  # does not infer, but that's OK?
    end

    # Values
    @test dropout(x1, 0) == x1
    @test dropout(x1.+0im, 0) == x1
    @test dropout(x1, 1) == zero.(x1)
    @test dropout(x1.+im, 1) == zero.(x1)

    d45 = dropout(trues(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(fill(pi, 1000), 0.2)
    @test sort(unique(dpi2)) ≈ [0, 5pi/4]
    d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]

    # Complex -- not worth too much optimisation, but should work!
    x2 = [1.0+0im, 2.0+1im, 3.0+3im]  # from Flux's tests
    @test dropout(x2, 0.5) isa Vector{ComplexF64}
    @test dropout(x2, 0.5; dims=1) isa Vector{ComplexF64}

    # Gradient rule
    y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
    dx = back(fill(3, 1000, 2))[3]
    @test !all(iszero, dx[:,2])  # this is why we save the random choices
    @test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]

    y2, back2 = rrule(dropout, rng, x2, 0.5)
    @test y2 isa Vector{ComplexF64}
    @test back2(one.(y2))[3] isa Vector{ComplexF64}

    @testset "Zygote" begin
        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}

        # p=0 & p=1
        @test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)
        @test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)

        # Second order
        f1(x) = sum(dropout(x, 0.5))
        @test_broken Zygote.hessian(f1, [1.0, 2.0, 3.0]) == zeros(3, 3)  # forward over reverse
        @test Zygote.hessian_reverse(f1, [1.0, 2.0, 3.0]) == zeros(3, 3)
    end

    # Bang
    y1 = fill!(similar(x1), NaN)
    @test dropout!(y1, x1, 0.0) == x1
    @test y1 == x1
    @test dropout!(rng, y1, x1, 1) == zero(x1)
    @test y1 == zero(x1)

    # Errors
    @test_throws ArgumentError dropout(x1, -1)
    @test_throws ArgumentError dropout(x1, 2)
    @test_throws ArgumentError dropout!(y1, x1, 3)
end
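The `dx[:,2]` test above exists because the pullback must reuse exactly the mask drawn in the forward pass. The TODO comment in `src/dropout.jl` mentions an alternative that saves that memory: copy the RNG before sampling and replay it in the pullback. A hypothetical CPU-only sketch of the idea, assuming an RNG that supports `copy` (as `MersenneTwister` does; `CUDA.RNG` would first need such a method, as the TODO says):

using Random

# dropout_replay is hypothetical, not part of NNlib: its pullback
# re-generates the mask from a saved RNG state instead of storing it.
function dropout_replay(rng::MersenneTwister, A::AbstractArray, p::Real)
    T = float(real(eltype(A)))
    val = convert(T, 1/(1-p))
    saved = copy(rng)                 # snapshot the state before sampling
    mask = rand(rng, T, size(A))
    Y = @. ((mask > p) * val) * A
    function back(dY)
        mask2 = rand(copy(saved), T, size(A))  # same state => identical mask
        @. ((mask2 > p) * val) * dY
    end
    return Y, back
end

Y, back = dropout_replay(MersenneTwister(0), ones(5), 0.5)
back(ones(5)) == Y   # with A = ones, the replayed gradient equals Y exactly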

test/runtests.jl

+4
@@ -52,6 +52,10 @@ include("test_utils.jl")
     include("ctc.jl")
 end

+@testset "Dropout" begin
+    include("dropout.jl")
+end
+
 @testset "Fold/Unfold" begin
     include("fold.jl")
 end
