Skip to content

Commit 65009b2

Browse files
committed
add dropout
1 parent 57268d1 commit 65009b2

File tree

5 files changed

+202
-0
lines changed

5 files changed

+202
-0
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
10+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1112
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1213

src/NNlib.jl

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ for f in ACTIVATIONS
4040
end
4141
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases
4242

43+
include("dropout.jl")
44+
export dropout, dropout!
45+
4346
include("softmax.jl")
4447
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
4548
logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp

src/dropout.jl

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using Random, ChainRulesCore
2+
"""
    dropout([rng], A, p; dims=:)

Return an array like `A` in which every entry has either been set to zero
(this happens with probability `p`) or been rescaled by the factor `1/(1-p)`.

Unless `dims` is given, each entry is decided independently.
With `dims=1`, one random decision is made per value of the 1st index,
so that every row of a matrix is either zeroed or kept as a whole.

The random number generator to use may be passed as an optional first argument.

# Examples
```
julia> dropout(ones(2, 10), 1/5)
2×10 Matrix{Float64}:
 1.25  1.25  0.0   1.25  1.25  1.25  1.25  1.25  1.25  1.25
 1.25  1.25  1.25  0.0   1.25  1.25  0.0   1.25  1.25  1.25

julia> mean(dropout(ones(10^4, 5), 0.3), dims=1)
1×5 Matrix{Float64}:
 0.996  1.00171  1.00629  0.998714  0.992429

julia> dropout(ones(5, 5), 0.7, dims=1)  # whole row the same
5×5 Matrix{Float64}:
 3.33333  3.33333  3.33333  3.33333  3.33333
 0.0      0.0      0.0      0.0      0.0
 0.0      0.0      0.0      0.0      0.0
 3.33333  3.33333  3.33333  3.33333  3.33333
 0.0      0.0      0.0      0.0      0.0

julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
1×5 Matrix{Float64}:
 1.00571  1.00571  1.00571  1.00571  1.00571
```
"""
function dropout(A::AbstractArray, p::Real; dims = :)
    # No generator supplied: pick one based on the input array, then forward.
    rng = _rng_from_array(A)
    return dropout(rng, A, p; dims = dims)
end
40+
# Out-of-place dropout with an explicit random number generator.
# The result has the floating-point element type `float(eltype(A))`.
# Throws `ArgumentError` unless `0 <= p <= 1`.
function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
    T = float(eltype(A))
    if !(0 <= p <= 1)
        throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
    end
    # Fast path for `p == 0`: nothing is dropped, so only the element type is converted.
    # It is debatable whether fast paths are wanted at all: this one tries to be
    # type-stable but cannot guarantee it, and the rrule has no such fast path.
    p > 0 || return convert(AbstractArray{T}, A)
    out = similar(A, T)
    prob = convert(real(T), p)
    return _dropout!(rng, out, A, prob, dims)
end
54+
"""
    dropout!(B, A, p; dims=:)

In-place version of [`dropout`](@ref): overwrites `B` with the result of applying
dropout with probability `p` to `A`, and returns `B`.
This does exactly `B .= dropout(A, p; dims)`, without allocating the output.

`B` must have the same size as `A`, otherwise a `DimensionMismatch` is thrown.
An `ArgumentError` is thrown unless `0 <= p <= 1`.
Note that `p` is converted to `real(eltype(B))`, and random numbers of the
element type of `B` are drawn, so `B` should have a floating-point element type.
"""
function dropout!(dst::AbstractArray, src::AbstractArray, p::Real; dims=:)
    size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
    0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
    if p > 0
        # Choose the generator from the input array. (This previously referred to an
        # undefined variable `A`, so every call with `p > 0` threw an `UndefVarError`.)
        rng = _rng_from_array(src)
        pT = convert(real(eltype(dst)), p)
        return _dropout!(rng, dst, src, pT, dims)
    else
        # Nothing is dropped and the scale factor is 1, so this is a plain copy.
        return copyto!(dst, src)
    end
end
72+
# Easy case, `dims = :`. One random number is needed per element, so the output
# array itself can safely hold them and no extra storage is allocated.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)
    scale = convert(eltype(dst), 1/(1-p))
    # Fill the output with random numbers; each is then replaced by the final value.
    rand!(rng, dst)
    # Same as `dst .= (dst.>p) .* scale .* src`, but that form hits a SIMD bug.
    keep_and_scale(u, x) = (u > p) * scale * x
    _fast_broadcast!(keep_and_scale, dst, src)
    return dst
end
83+
# For any other `dims`, a separate array of random numbers has to be allocated.
# It has the size of `src` along the dimensions listed in `dims`, and size 1 elsewhere.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)
    mask_size = ntuple(ndims(src)) do d
        d in dims ? size(src, d) : 1
    end
    mask = similar(dst, mask_size)
    rand!(rng, mask)
    scale = convert(eltype(dst), 1/(1-p))
    # Two-pass strategy: first turn the random numbers into factors (zero or `scale`),
    # then broadcast those over `src`. The one-pass alternative would be
    #   dst .= (mask.>p) .* scale .* src
    _fast_broadcast!(u -> (u > p) * scale, mask)
    dst .= mask .* src
    return dst
end
97+
# Reverse-mode rule for `dropout(rng, A, p; dims)`.
# The gradient needs to keep the random choices made in the forward pass. Storing a
# BitArray would be the minimum, but keeping the random numbers themselves (`keep`)
# turns out to be faster & simpler.
# Returns the primal output `Y` and a pullback, which gives a tangent only for `A`.
# Throws `ArgumentError` unless `0 <= p <= 1`, like `dropout` itself.
function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
    T = float(eltype(A))
    # Under AD this rule runs instead of `dropout`, so it must validate `p` itself.
    # Without this check, `p = 1.5` would silently give zeros everywhere, and a
    # negative `p` would silently rescale the input.
    0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
    val = convert(T, 1/(1-p))
    # Random numbers: one per element, or one per slice along `dims`.
    keep = if dims isa Colon
        similar(A, T)
    else
        similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))
    end
    rand!(rng, keep)
    Y = @. (keep>p) * A * val
    function dropout_back(Δ)
        dY = unthunk(Δ)
        # Same mask and scale as the forward pass, applied to the incoming gradient.
        dA = @. (keep>p) * dY * val
        return (NoTangent(), NoTangent(), dA, NoTangent())
    end
    return Y, dropout_back
end
117+
"""
    _fast_broadcast!(f, x, y, z...)

In-place broadcast equivalent to `x .= f.(x, y, z...)`, returning `x`.
It exists only to work around an issue with broadcasting that prevents SIMD
in such cases, and can be removed once
https://github.com/JuliaLang/julia/issues/43153 is fixed.

Internal helper, not intended for general use. Does not check sizes!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
    lazy = Broadcast.broadcasted(f, x, yz...)
    bc = Broadcast.instantiate(lazy)
    # Explicit loop over the lazy broadcast, writing each result back into `x`.
    @simd ivdep for idx in eachindex(bc)
        @inbounds x[idx] = bc[idx]
    end
    return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
    # Arrays other than `Array` take the ordinary path; CUDA does not suffer from this bug.
    return broadcast!(f, x, x, yz...)
end
138+
139+
"""
    _rng_from_array(x)

Return the random number generator most appropriate for the array `x`.

Only the generic method exists at present, which returns `Random.default_rng()`.
The intention is for `CuArray` to get `CUDA.default_rng()` instead;
that method is not defined here yet.
"""
function _rng_from_array(::AbstractArray)
    return Random.default_rng()
end
# Planned GPU method, not enabled:
# _rng_from_array(::CuArray) = CUDA.default_rng()

# Choosing a generator is not something to differentiate through.
@non_differentiable _rng_from_array(::Any)
151+
152+

test/dropout.jl

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using NNlib, Test, Statistics, Random
2+
using Zygote, StableRNGs, ChainRulesCore
3+
@testset "dropout" begin
    # Basics: output size, inferrability and element type
    x1 = randn(Float32, 3, 4)
    @test size(@inferred dropout(x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)
    @test eltype(dropout(x1, 0.1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=(1,2))) == Float32

    # Same, with an explicit random number generator
    rng = Random.default_rng()
    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

    # Values: the mean is preserved, and entries are either 0 or x/(1-p)
    d45 = dropout(trues(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(fill(pi, 1000), 0.2)
    @test sort(unique(dpi2)) ≈ [0, 5pi/4]
    d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]

    # Gradient rule
    y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
    dx = back(fill(3, 1000, 2))[3]
    @test !all(iszero, dx[:,2])  # this is why we save the random choices
    @test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]

    @testset "Zygote" begin
        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}

        f1(x) = sum(dropout(x, 0.5))
        @test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3)  # forward over reverse
        @test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)
    end
end
42+

test/runtests.jl

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ include("test_utils.jl")
5252
include("ctc.jl")
5353
end
5454

55+
@testset "Dropout" begin
56+
include("dropout.jl")
57+
end
58+
5559
@testset "Fold/Unfold" begin
5660
include("fold.jl")
5761
end

0 commit comments

Comments
 (0)