Skip to content

Commit 715f4bc

Browse files
committed
Only seed structurally non-zero entries
1 parent f3340a3 commit 715f4bc

File tree

4 files changed

+61
-37
lines changed

4 files changed

+61
-37
lines changed

src/apiutils.jl

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,48 +40,63 @@ end
4040
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)
4141
end
4242

43-
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
44-
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
45-
duals .= Dual{T,V,N}.(x, Ref(seed))
46-
return duals
47-
end
48-
49-
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
50-
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
51-
dual_inds = 1:N
52-
duals[dual_inds] .= Dual{T,V,N}.(view(x,dual_inds), seeds)
53-
return duals
54-
end
55-
56-
# Triangular matrices
57-
function _nonzero_indices(x::UpperTriangular)
43+
# Only seed indices that are structurally non-zero
# Generic fallback: every index of `x` is reported, i.e. no entry is treated as a
# structural zero.
function _structural_nonzero_indices(x::AbstractArray)
    return eachindex(x)
end
45+
# Upper-triangular storage: lazily yield `CartesianIndex(row, col)` for every entry on or
# above the diagonal, column by column (rows `1:col` within each column).
function _structural_nonzero_indices(x::UpperTriangular)
    ncols = size(x, 1)
    return Iterators.flatten(
        (CartesianIndex(row, col) for row in 1:col) for col in 1:ncols
    )
end
61-
function _nonzero_indices(x::LowerTriangular)
49+
# Lower-triangular storage: lazily yield `CartesianIndex(row, col)` for every entry on or
# below the diagonal, column by column (rows `col:n` within each column).
function _structural_nonzero_indices(x::LowerTriangular)
    nrows = size(x, 1)
    return Iterators.flatten(
        (CartesianIndex(row, col) for row in col:nrows) for col in 1:nrows
    )
end
65-
function seed!(duals::Union{LowerTriangular{Dual{T,V,N}},UpperTriangular{Dual{T,V,N}}}, x, seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
66-
for (idx, seed) in zip(_nonzero_indices(duals), seeds)
53+
# Diagonal storage: only the main-diagonal positions are reported, as given by `diagind`.
function _structural_nonzero_indices(x::Diagonal)
    return diagind(x)
end
54+
55+
# Fill every structurally non-zero entry of `duals` with a `Dual` that pairs the matching
# value of `x` with the same `seed` (zero partials by default). Returns `duals`.
# Throws `ArgumentError` when `eachindex(duals)` and `eachindex(x)` differ.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    eachindex(duals) == eachindex(x) || throw(
        ArgumentError("indices of input array and array of duals are not identical")
    )
    for position in _structural_nonzero_indices(duals)
        duals[position] = Dual{T,V,N}(x[position], seed)
    end
    return duals
end
7165

66+
# Seed the structurally non-zero entries of `duals` in iteration order: the i-th such
# entry becomes a `Dual` pairing the matching value of `x` with `seeds[i]`.
# At most `N` entries are written (one per seed); any further entries are left untouched.
# Returns `duals`. Throws `ArgumentError` when `eachindex(duals)` and `eachindex(x)`
# differ.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    if eachindex(duals) != eachindex(x)
        throw(
            ArgumentError("indices of input array and array of duals are not identical")
        )
    end
    # `zip` stops at the shorter iterator, so `seeds` is never indexed past `N` even when
    # `duals` has more than `N` structurally non-zero entries.
    for (idx, seed) in zip(_structural_nonzero_indices(duals), seeds)
        duals[idx] = Dual{T,V,N}(x[idx], seed)
    end
    return duals
end
76+
7277
# Seed one chunk of `duals` with a single shared `seed` (zero partials by default):
# after skipping the first `index - 1` structurally non-zero entries, the next `N` entries
# (fewer if the array ends first) become `Dual`s pairing the matching value of `x` with
# `seed`. Returns `duals`.
# Throws `ArgumentError` when `eachindex(duals)` and `eachindex(x)` differ.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    if eachindex(duals) != eachindex(x)
        throw(
            ArgumentError("indices of input array and array of duals are not identical")
        )
    end
    offset = index - 1
    remaining = Iterators.drop(_structural_nonzero_indices(duals), offset)
    # Bound the walk to one chunk of `N` entries; without `take`, every entry from `index`
    # to the end of the array would be rewritten on each call.
    for idx in Iterators.take(remaining, N)
        duals[idx] = Dual{T,V,N}(x[idx], seed)
    end
    return duals
end
7989

8090
# Seed one chunk of `duals` with per-entry seeds: after skipping the first `index - 1`
# structurally non-zero entries, the i-th following entry (for i up to `chunksize`)
# becomes a `Dual` pairing the matching value of `x` with `seeds[i]`. Returns `duals`.
# Throws `ArgumentError` when `eachindex(duals)` and `eachindex(x)` differ.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    eachindex(duals) == eachindex(x) || throw(
        ArgumentError("indices of input array and array of duals are not identical")
    )
    skipped = index - 1
    chunk_positions = Iterators.drop(_structural_nonzero_indices(duals), skipped)
    for (slot, position) in enumerate(chunk_positions)
        # Stop once the requested number of entries has been seeded.
        if slot > chunksize
            break
        end
        duals[position] = Dual{T,V,N}(x[position], seeds[slot])
    end
    return duals
end

src/gradient.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,20 @@ function extract_gradient!(::Type{T}, result::DiffResult, dual::Dual) where {T}
6363
end
6464

6565
# Non-dual (plain `Real`) output: overwrite every entry of `result` with `zero(y)` and
# return `result`.
function extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T}
    return fill!(result, zero(y))
end
66-
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}= copyto!(result, partials(T, dual))
67-
68-
# Triangular matrices
69-
function extract_gradient!(::Type{T}, result::Union{UpperTriangular,LowerTriangular}, dual::Dual) where {T}
70-
for (idx, p) in zip(_nonzero_indices(result), partials(T, dual))
71-
result[idx] = p
66+
# Copy the partials of `dual` (checked against tag `T` by `partials`) into the
# structurally non-zero entries of `result`, in iteration order. Returns `result`.
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
    idxs = _structural_nonzero_indices(result)
    # `zip` stops at the shorter iterator, so no partial beyond those stored in `dual` is
    # requested when `result` has more structurally non-zero entries than partials.
    for (idx, p) in zip(idxs, partials(T, dual))
        result[idx] = p
    end
    return result
end
end
7573

7674
# Write one chunk of partials into `result`: after skipping the first `index - 1`
# structurally non-zero entries, the i-th following entry (for i up to `chunksize`)
# receives `partials(T, dual, i)`. Returns `result`.
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
    skipped = index - 1
    chunk_positions = Iterators.drop(_structural_nonzero_indices(result), skipped)
    for (slot, position) in enumerate(chunk_positions)
        # Stop once the requested number of partials has been written.
        if slot > chunksize
            break
        end
        result[position] = partials(T, dual, slot)
    end
    return result
end
end

src/prelude.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,15 @@ function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHO
1212
Base.@nif 12 d->(N == d) d->(Chunk{d}()) d->(Chunk{N}())
1313
end
1414

15+
# Number of entries of `x` that may be structurally non-zero.
# Generic arrays: every element counts.
function _length_structural_nonzero_indices(x::AbstractArray)
    return length(x)
end

# Triangular matrices: the n * (n + 1) / 2 entries on and to one side of the diagonal.
function _length_structural_nonzero_indices(x::Union{LowerTriangular,UpperTriangular})
    order = size(x, 1)
    product = order * (order + 1)
    return product >> 1
end

# Diagonal matrices: one entry per row.
function _length_structural_nonzero_indices(x::Diagonal)
    return size(x, 1)
end
21+
1522
# Pick a chunk for the array `x`: the input length handed to the integer `Chunk` method is
# the number of structurally non-zero entries of `x`.
function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
    n_structural = _length_structural_nonzero_indices(x)
    return Chunk(n_structural, threshold)
end
1825

1926
# Constrained to `N <= threshold`, minimize (in order of priority):

test/GradientTest.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,13 @@ end
227227
end
228228

229229
# issue #738
230-
@testset "LowerTriangular and UpperTriangular" begin
231-
M = rand(3, 3)
232-
for T in (LowerTriangular, UpperTriangular)
233-
@test ForwardDiff.gradient(sum, T(randn(3, 3))) == T(ones(3, 3))
234-
@test ForwardDiff.gradient(x -> dot(M, x), T(randn(3, 3))) == T(M)
230+
@testset "LowerTriangular, UpperTriangular and Diagonal" begin
    structured_types = (LowerTriangular, UpperTriangular, Diagonal)
    # NOTE(review): the sizes presumably cover both single-chunk and multi-chunk
    # evaluation — confirm against the default chunk threshold.
    for n in (3, 10, 20)
        M = rand(n, n)
        for W in structured_types
            # The gradient must come back with the same structure as the input.
            @test ForwardDiff.gradient(sum, W(randn(n, n))) == W(ones(n, n))
            @test ForwardDiff.gradient(x -> dot(M, x), W(randn(n, n))) == W(M)
        end
    end
end
237239

0 commit comments

Comments
 (0)