Adding complex broadcasting for gradients on the GPU #1324
```diff
@@ -26,7 +26,7 @@ end
     g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
     @test g_gpu isa CuArray
     @test g_gpu |> collect ≈ g

     w(x) = sum(broadcast(log, x))
     g = gradient(x -> w(x), a)[1]
     g_gpu = gradient(x -> w(x), a_gpu)[1]
```
```diff
@@ -38,7 +38,7 @@ end
     @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
     g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]  # was Can't differentiate gc_preserve_end expression
-    @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1]  # was KernelException -- not fixed by PR #1018
+    @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

     # Projection: eltype preservation:
     @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
```
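The `@test_skip` above can be promoted to a live `@test` because the non-differentiable comparison is handled cleanly: `x .> 3` is piecewise constant, so Zygote reports no gradient for it, and in the composite loss the count simply scales the differentiable part. A minimal CPU sketch of that behaviour, assuming the CPU path mirrors the GPU assertions above:

```julia
using Zygote

x = Float32[1, 2, 3, 4, 5, 6]

# Comparisons are piecewise constant, so Zygote reports no gradient for them.
@assert Zygote.gradient(x -> sum(x .> 3), x) == (nothing,)

# In the composite loss the discrete count acts as a constant factor, so the
# gradient is that of sum(x .^ 3) scaled by 1 / count(x .> 3).
g = Zygote.gradient(x -> sum(x .^ 3) / count(x .> 3), x)[1]
@assert g ≈ 3 .* x .^ 2 ./ count(x .> 3)
```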
```diff
@@ -90,40 +90,40 @@ end
 @testset "gradient algebra" begin
     w, b = rand(2) |> cu, rand(2) |> cu
     x1, x2 = rand(2) |> cu, rand(2) |> cu

     gs1 = gradient(() -> sum(w .* x1), Params([w]))
     gs2 = gradient(() -> sum(w .* x2), Params([w]))

     @test .- gs1 isa Grads
     @test gs1 .- gs2 isa Grads
     @test .+ gs1 isa Grads
     @test gs1 .+ gs2 isa Grads
     @test 2 .* gs1 isa Grads
     @test (2 .* gs1)[w] ≈ 2 * gs1[w]
     @test gs1 .* 2 isa Grads
     @test gs1 ./ 2 isa Grads
     @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

     gs12 = gs1 .+ gs2
     gs1 .+= gs2
     @test gs12[w] ≈ gs1[w]

     gs3 = gradient(() -> sum(w .* x1), Params([w, b]))  # grad nothing with respect to b
     gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

     @test .- gs3 isa Grads
     @test gs3 .- gs4 isa Grads
     @test .+ gs3 isa Grads
     @test gs3 .+ gs4 isa Grads
     @test 2 .* gs3 isa Grads
     @test gs3 .* 2 isa Grads
     @test gs3 ./ 2 isa Grads
     @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
     @test (gs3 .+ gs4)[b] ≈ gs4[b]

     @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
     gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
     @test gs3 isa Grads

     @test_throws ArgumentError gs1 .+ gs4
 end
```
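The algebra exercised above is the practical one for implicit-parameter training: `Grads` supports broadcasted arithmetic with other `Grads` objects (and `IdDict`s), keyed by the parameter arrays. A small sketch of the accumulate-then-average pattern these tests cover; plain `Array`s keep it runnable without a GPU, while the testset itself uses `CuArray`s:

```julia
using Zygote
using Zygote: Params, Grads

w = rand(Float32, 2)
x1, x2 = rand(Float32, 2), rand(Float32, 2)

gs = gradient(() -> sum(w .* x1), Params([w]))
gs .+= gradient(() -> sum(w .* x2), Params([w]))  # in-place accumulation
gs = gs ./ 2                                      # elementwise scaling, still a Grads
@assert gs isa Grads
@assert gs[w] ≈ (x1 .+ x2) ./ 2
```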
```diff
@@ -140,3 +140,21 @@ end
     @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
 end

+@testset "CUDA complex broadcasting" begin
+    # Issue 961 and 1121 and 1215
+    x = rand(Float32, 50)
+    y = complex(rand(Float32, 50))
+
+    xgpu = cu(x)
+    ygpu = cu(y)
+
+    g1 = Zygote.gradient(x -> sum(abs2, x), ygpu)[1]
+    g2 = Zygote.gradient(x -> sum(abs2.(x)), ygpu)[1]
+    g3 = Zygote.gradient(x -> sum(abs2, x), y)[1]
+    @test g1 isa CUDA.CuArray{ComplexF32}
+    @test g2 isa CUDA.CuArray{ComplexF32}
+    @test collect(g1) ≈ collect(g2)
+    @test collect(g1) ≈ g3
+end
```
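The new testset checks that both the reduction path (`sum(abs2, x)`) and the broadcast path (`sum(abs2.(x))`) return a `ComplexF32` gradient on the GPU that agrees with the CPU result. For reference, a CPU-only sketch of the identity being relied on, assuming Zygote's steepest-ascent convention for real-valued functions of complex arguments, under which the gradient of `sum(abs2, z)` is `2z`:

```julia
using Zygote

y = complex(rand(Float32, 5))

g_red   = Zygote.gradient(z -> sum(abs2, z), y)[1]   # reduction path
g_bcast = Zygote.gradient(z -> sum(abs2.(z)), y)[1]  # broadcast path
@assert g_red ≈ g_bcast
@assert g_red ≈ 2 .* y   # steepest-ascent convention for f: C^n -> R
```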
Review discussion:

**Reviewer:** Should this be `Union{Dual, Dual{<:Complex}}`? You'd have to try pretty hard, but I think the Complex path expects `Dual` inside.

**Author:** I thought it was the other way around? At least that is what I am constructing in `dual_function`. ForwardDiff.jl also defines `Dual <: Real`, so I think defining it the other way would break things. However, I probably want to be a little more specific here and do …

**Reviewer:** Yes, sorry, that's what I was thinking but didn't type...
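For context on the type question: ForwardDiff declares `Dual <: Real`, so a `Dual` cannot hold a complex value and `Dual{<:Complex}` is not constructible; the forward-mode complex path instead puts one `Dual` inside each component of a `Complex`, which is what `dual_function` builds. A minimal sketch of that layout; the per-component seeding shown here is illustrative, not the PR's exact code:

```julia
using ForwardDiff: Dual, value, partials

@assert Dual <: Real   # so the value type of a Dual must itself be Real

# One Dual per component, seeded against (re, im):
z = 1.0f0 + 2.0f0im
zd = Complex(Dual(real(z), 1f0, 0f0), Dual(imag(z), 0f0, 1f0))

# abs2 accepts Complex{<:Dual} exactly as it does ComplexF32, and the
# partials recover d|z|^2/d(re) = 2·re(z) and d|z|^2/d(im) = 2·im(z).
r = abs2(zd)
@assert value(r) ≈ 5f0
@assert partials(r, 1) ≈ 2f0 && partials(r, 2) ≈ 4f0
```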