Skip to content

Commit 55e246d

Browse files
committed
nan & complex fixes
1 parent 4111cb2 commit 55e246d

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

src/dropout.jl

+10-8
Original file line numberDiff line numberDiff line change
@@ -72,23 +72,25 @@ end
7272

7373
# This is the easy case in that we can safely use the output array for random numbers.
7474
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)
75-
val = convert(eltype(dst), 1/(1-p))
75+
T = real(eltype(dst))
76+
val = convert(T, 1/(1-p))
7677
rand!(rng, dst)
7778
## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast!
7879
# dst .= (dst.>p) .* val .* src
7980
_fast_broadcast!(dst, src) do q, x
80-
(q>p) * val * x
81+
((real(q)>p) * val) * x
8182
end
8283
dst
8384
end
8485

8586
# For other dims, we do need to allocate something.
8687
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)
87-
tmp = similar(dst, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src)))
88+
T = real(eltype(dst))
89+
tmp = similar(dst, T, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src)))
8890
rand!(rng, tmp)
89-
val = convert(eltype(dst), 1/(1-p))
91+
val = convert(T, 1/(1-p))
9092
## One-pass strategy -- faster on GPU
91-
dst .= (tmp.>p) .* val .* src
93+
dst .= ((tmp.>p) .* val) .* src
9294
## Two-pass strategy -- slightly faster on some CPUs?
9395
# _fast_broadcast!(tmp) do q
9496
# (q>p) * val
@@ -99,18 +101,18 @@ end
99101
# The gradient needs to keep the random choices made, thus store at least a BitArray,
100102
# but the following way turns out to be faster & simpler:
101103
function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
102-
T = float(eltype(A))
104+
T = float(real(eltype(A)))
103105
val = convert(T, 1/(1-p))
104106
keep = if dims isa Colon
105107
similar(A, T)
106108
else
107109
similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))
108110
end
109111
rand!(rng, keep)
110-
Y = @. (keep>p) * A * val
112+
Y = @. ((keep>p) * val) * A
111113
function dropout_back(Δ)
112114
dY = unthunk(Δ)
113-
dA = @. (keep>p) * dY * val
115+
dA = @. ((keep>p) * val) * dY
114116
(NoTangent(), NoTangent(), dA, NoTangent())
115117
end
116118
return Y, dropout_back

test/dropout.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,49 @@ using Zygote, StableRNGs, ChainRulesCore
1616
@test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)
1717

1818
# Values
19+
@test dropout(x1, 0) == x1
20+
@test dropout(x1.+0im, 0) == x1
21+
@test dropout(x1, 1) == zero.(x1)
22+
@test dropout(x1.+im, 1) == zero.(x1)
23+
1924
d45 = dropout(trues(100, 100, 100), 0.45)
2025
@test mean(d45) ≈ 1 atol=1e-2
2126
dpi2 = dropout(fill(pi, 1000), 0.2)
2227
@test sort(unique(dpi2)) ≈ [0, 5pi/4]
2328
d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
2429
@test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]
2530

31+
# Complex -- not worth too much optimisation, but should work!
32+
x2 = [1.0+0im,2.0+1im,3.0+3im] # from Flux's tests
33+
@test dropout(x, 0.5) isa Vector{ComplexF64}
34+
@test dropout(x, 0.5; dims=1) isa Vector{ComplexF64}
35+
2636
# Gradient rule
2737
y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
2838
dx = back(fill(3, 1000, 2))[3]
2939
@test !all(iszero, dx[:,2]) # this is why we save the random choices
3040
@test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]
3141

42+
y2, back2 = rrule(dropout, rng, x2, 0.5)
43+
@test y2 isa Vector{ComplexF64}
44+
@test back2(one.(y2))[3] isa Vector{ComplexF64}
45+
3246
@testset "Zygote" begin
3347
@test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
3448
@test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
3549
@test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}
3650

51+
# p=0 & p=1
52+
@test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)
53+
@test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)
54+
55+
# Second order
3756
f1(x) = sum(dropout(x, 0.5))
3857
@test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3) # forward over reverse
3958
@test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)
4059
end
41-
end
4260

61+
# Errors
62+
@test_throws ArgumentError dropout(x1, -1)
63+
@test_throws ArgumentError dropout(x1, 2)
64+
end

0 commit comments

Comments
 (0)