@@ -16,27 +16,49 @@ using Zygote, StableRNGs, ChainRulesCore
16
16
@test size (@inferred dropout (rng, x1, 0.1 ; dims= 2 )) == (3 , 4 )
17
17
18
18
# Values
19
+ @test dropout (x1, 0 ) == x1
20
+ @test dropout (x1.+ 0im , 0 ) == x1
21
+ @test dropout (x1, 1 ) == zero .(x1)
22
+ @test dropout (x1.+ im, 1 ) == zero .(x1)
23
+
19
24
d45 = dropout (trues (100 , 100 , 100 ), 0.45 )
20
25
@test mean (d45) ≈ 1 atol= 1e-2
21
26
dpi2 = dropout (fill (pi , 1000 ), 0.2 )
22
27
@test sort (unique (dpi2)) ≈ [0 , 5pi / 4 ]
23
28
d33 = dropout (fill (3 , 10 , 1000 ), 0.3 , dims= 2 )
24
29
@test sort (unique (vec (d33))) ≈ [0 , 3 / (1 - 0.3 )]
25
30
31
+ # Complex -- not worth too much optimisation, but should work!
32
+ x2 = [1.0 + 0im ,2.0 + 1im ,3.0 + 3im ] # from Flux's tests
33
+ @test dropout (x, 0.5 ) isa Vector{ComplexF64}
34
+ @test dropout (x, 0.5 ; dims= 1 ) isa Vector{ComplexF64}
35
+
26
36
# Gradient rule
27
37
y, back = rrule (dropout, rng, hcat (trues (1000 ), falses (1000 )), 0.45 )
28
38
dx = back (fill (3 , 1000 , 2 ))[3 ]
29
39
@test ! all (iszero, dx[:,2 ]) # this is why we save the random choices
30
40
@test sort (unique (vec (dx))) ≈ [0 , 3 / (1 - 0.45 )]
31
41
42
+ y2, back2 = rrule (dropout, rng, x2, 0.5 )
43
+ @test y2 isa Vector{ComplexF64}
44
+ @test back2 (one .(y2))[3 ] isa Vector{ComplexF64}
45
+
32
46
@testset " Zygote" begin
33
47
@test Zygote. gradient (x -> sum (dropout (x, 0.3 )), x1)[1 ] isa Matrix{Float32}
34
48
@test Zygote. gradient (x -> sum (dropout (rng, x, 0.3 )), x1)[1 ] isa Matrix{Float32}
35
49
@test Zygote. gradient (x -> sum (dropout (x, 0.3 , dims= 1 )), x1)[1 ] isa Matrix{Float32}
36
50
51
+ # p=0 & p=1
52
+ @test Zygote. gradient (x -> sum (dropout (x, 0 )), x1)[1 ] == ones (3 ,4 )
53
+ @test Zygote. gradient (x -> sum (dropout (x, 1 )), x1)[1 ] == zeros (3 ,4 )
54
+
55
+ # Second order
37
56
f1 (x) = sum (dropout (x, 0.5 ))
38
57
@test_broken Zygote. hessian (f1, [1.0 ,2.0 ,3.0 ]) == zeros (3 , 3 ) # forward over reverse
39
58
@test Zygote. hessian_reverse (f1, [1.0 ,2.0 ,3.0 ]) == zeros (3 , 3 )
40
59
end
41
- end
42
60
61
+ # Errors
62
+ @test_throws ArgumentError dropout (x1, - 1 )
63
+ @test_throws ArgumentError dropout (x1, 2 )
64
+ end
0 commit comments