-
-
Couldn't load subscription status.
- Fork 216
Adding complex broadcasting for gradients on the GPU #1324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
807d689
2972faf
51dc882
0635ba4
739e896
a0e21e6
5a83493
6742644
2aa06c6
851ab33
f42d940
95a6b5b
b29f090
40fdb29
15c33ad
5e53ada
2c4857b
9fc2180
c685798
efc4f67
51e3ba3
c888db8
83ed917
7b0044b
2bb3b65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -120,6 +120,9 @@ end | |||||
| @adjoint broadcasted(::typeof(imag), x::Numeric) = | ||||||
| imag.(x), z̄ -> (nothing, im .* real.(z̄)) | ||||||
|
|
||||||
| @adjoint broadcasted(::typeof(abs2), x::Numeric) = | ||||||
| abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) | ||||||
|
|
||||||
| @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) | ||||||
| y = b === false ? a : a .+ b | ||||||
| y, Δ -> (nothing, Δ, nothing) | ||||||
|
|
@@ -190,7 +193,7 @@ _dual_safearg(x) = false | |||||
| # Avoid generic broadcasting in two easy cases: | ||||||
| if T == Bool | ||||||
| return (f.(args...), _ -> nothing) | ||||||
| elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() | ||||||
| elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() | ||||||
| return broadcast_forward(f, args...) | ||||||
| end | ||||||
| len = inclen(args) | ||||||
|
|
@@ -232,23 +235,44 @@ end | |||||
| import ForwardDiff | ||||||
| using ForwardDiff: Dual | ||||||
|
|
||||||
| dual(x, p) = x | ||||||
| dual(x::Real, p) = Dual(x, p) | ||||||
| dual(x::Bool, p) = x | ||||||
|
|
||||||
| # We do this because it ensures type stability, so it compiles nicely on the GPU | ||||||
| dual(x, i, N) = x | ||||||
| dual(x::Bool, i, ::Val{N}) where {N} = x | ||||||
| dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(j-> i==j, Val(N))) | ||||||
| # For complex inputs: since ForwardDiff.jl doesn't play nicely with complex numbers, we | ||||||
| # construct a Complex of dual numbers and tag the real and imaginary parts separately | ||||||
| function dual(x::Complex, i, ::Val{N}) where {N} | ||||||
| re_dual = Dual(real(x), ntuple(j->i==j, Val(2N))) | ||||||
| im_dual = Dual(imag(x), ntuple(j->(N+i)==j, Val(2N))) | ||||||
ptiede marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| return Complex(re_dual, im_dual) | ||||||
| end | ||||||
|
|
||||||
| function dual_function(f::F) where F | ||||||
| function (args::Vararg{Any,N}) where N | ||||||
| ds = map(args, ntuple(identity,Val(N))) do x, i | ||||||
| dual(x, ntuple(j -> i==j, Val(N))) | ||||||
| function (args::Vararg{Any,N}) where N | ||||||
| ds = map(args, ntuple(identity,Val(N))) do x, i | ||||||
| tmp = dual(x, i, Val(N)) | ||||||
CarloLucibello marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| return tmp | ||||||
| end | ||||||
| return f(ds...) | ||||||
| end | ||||||
| return f(ds...) | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
|
|
||||||
| @inline function broadcast_forward(f, args::Vararg{Any,N}) where N | ||||||
| valN = Val(N) | ||||||
| out = dual_function(f).(args...) | ||||||
| eltype(out) <: Dual || return (out, _ -> nothing) | ||||||
| T = eltype(out) | ||||||
| T <: Union{Dual, Complex} || return (out, _ -> nothing) | ||||||
|
||||||
| T <: Union{Dual, Complex} || return (out, _ -> nothing) | |
| T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sorry, that's what I was thinking but didn't type...
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ end | |
| g_gpu = gradient(x -> v(x, 7), a_gpu)[1] | ||
| @test g_gpu isa CuArray | ||
| @test g_gpu |> collect ≈ g | ||
|
|
||
| w(x) = sum(broadcast(log, x)) | ||
| g = gradient(x -> w(x), a)[1] | ||
| g_gpu = gradient(x -> w(x), a_gpu)[1] | ||
|
|
@@ -38,7 +38,7 @@ end | |
| @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,) | ||
| g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression | ||
| @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018 | ||
| @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] | ||
| @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] | ||
|
|
||
| # Projection: eltype preservation: | ||
| @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32} | ||
|
|
@@ -90,40 +90,40 @@ end | |
| @testset "gradient algebra" begin | ||
| w, b = rand(2) |> cu, rand(2) |> cu | ||
| x1, x2 = rand(2) |> cu, rand(2) |> cu | ||
| gs1 = gradient(() -> sum(w .* x1), Params([w])) | ||
| gs2 = gradient(() -> sum(w .* x2), Params([w])) | ||
|
|
||
| gs1 = gradient(() -> sum(w .* x1), Params([w])) | ||
| gs2 = gradient(() -> sum(w .* x2), Params([w])) | ||
|
|
||
| @test .- gs1 isa Grads | ||
| @test gs1 .- gs2 isa Grads | ||
| @test gs1 .- gs2 isa Grads | ||
| @test .+ gs1 isa Grads | ||
| @test gs1 .+ gs2 isa Grads | ||
| @test 2 .* gs1 isa Grads | ||
| @test gs1 .+ gs2 isa Grads | ||
| @test 2 .* gs1 isa Grads | ||
| @test (2 .* gs1)[w] ≈ 2 * gs1[w] | ||
| @test gs1 .* 2 isa Grads | ||
| @test gs1 ./ 2 isa Grads | ||
| @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] | ||
| @test gs1 .* 2 isa Grads | ||
| @test gs1 ./ 2 isa Grads | ||
| @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] | ||
|
|
||
| gs12 = gs1 .+ gs2 | ||
| gs1 .+= gs2 | ||
| @test gs12[w] ≈ gs1[w] | ||
| @test gs12[w] ≈ gs1[w] | ||
|
|
||
| gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b | ||
| gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) | ||
| gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) | ||
|
|
||
| @test .- gs3 isa Grads | ||
| @test gs3 .- gs4 isa Grads | ||
| @test gs3 .- gs4 isa Grads | ||
| @test .+ gs3 isa Grads | ||
| @test gs3 .+ gs4 isa Grads | ||
| @test 2 .* gs3 isa Grads | ||
| @test gs3 .* 2 isa Grads | ||
| @test gs3 ./ 2 isa Grads | ||
| @test gs3 .+ gs4 isa Grads | ||
| @test 2 .* gs3 isa Grads | ||
| @test gs3 .* 2 isa Grads | ||
| @test gs3 ./ 2 isa Grads | ||
| @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w] | ||
| @test (gs3 .+ gs4)[b] ≈ gs4[b] | ||
| @test (gs3 .+ gs4)[b] ≈ gs4[b] | ||
|
|
||
| @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads | ||
| gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3)) | ||
| @test gs3 isa Grads | ||
| @test gs3 isa Grads | ||
|
|
||
| @test_throws ArgumentError gs1 .+ gs4 | ||
| end | ||
|
|
@@ -140,3 +140,21 @@ end | |
| @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} | ||
| end | ||
|
|
||
|
|
||
| @testset "CUDA complex broadcasting" begin | ||
| # Issues 961, 1121, and 1215 | ||||||
| x = rand(Float32, 50) | ||
| y = complex(rand(Float32, 50)) | ||
|
||
|
|
||
| xgpu = cu(x) | ||
| ygpu = cu(y) | ||
|
|
||
|
|
||
| g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1] | ||
| g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1] | ||
| g3 = Zygote.gradient(x->sum(abs2, x), y)[1] | ||
| @test g1 isa CUDA.CuArray{ComplexF32} | ||
| @test g2 isa CUDA.CuArray{ComplexF32} | ||
| @test collect(g1) ≈ collect(g2) | ||
| @test collect(g1) ≈ g3 | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.