Enable insertion of known derivative in autodiff chain #176
KristofferC merged 15 commits into Ferrite-FEM:master
Conversation
Edit: naming outdated

Test cases

import Tensors: Dual, _insert_gradient, _extract_value
using Tensors, ForwardDiff
using BenchmarkTools
using Test
if !isdefined(Main, :DualTensor)
    const DualTensor = Dual{<:ForwardDiff.Tag{<:Any,<:AbstractTensor}}
end
g1(x) = x
f1(x) = g1(x)
g2(x) = x⋅x
f2(x) = g2(x)
g3(x) = x⊡x
f3(x) = g3(x)
function f1(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    fval = f1(_extract_value(x))
    ∇f = one(Tensor{4,dim})
    return _insert_gradient(fval, ∇f, x)
end
function f2(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    xval = _extract_value(x)
    fval = f2(xval)
    I2 = one(Tensor{2,dim})
    ∇f = otimesu(I2, transpose(xval)) + otimesu(xval, I2)
    return _insert_gradient(fval, ∇f, x)
end
function f3(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    xval = _extract_value(x)
    fval = f3(xval)
    ∇f = 2*xval
    return _insert_gradient(fval, ∇f, x)
end
for dim in 1:3
    a = rand(Tensor{2,dim})
    for (f, g) in ((f1,g1), (f2,g2), (f3,g3))
        println("f = $f, dim = $dim")
        @test gradient(f, a) == gradient(g, a)
        print("analytical: "); @btime gradient($f, $a);
        print("autodiff: "); @btime gradient($g, $a);
        println("")
    end
end

Results

f = f1, dim = 1
analytical: 0.800 ns (0 allocations: 0 bytes)
autodiff: 0.800 ns (0 allocations: 0 bytes)
f = f2, dim = 1
analytical: 6.900 ns (0 allocations: 0 bytes)
autodiff: 0.800 ns (0 allocations: 0 bytes)
f = f3, dim = 1
analytical: 0.800 ns (0 allocations: 0 bytes)
autodiff: 0.800 ns (0 allocations: 0 bytes)
f = f1, dim = 2
analytical: 13.126 ns (0 allocations: 0 bytes)
autodiff: 1.000 ns (0 allocations: 0 bytes)
f = f2, dim = 2
analytical: 23.571 ns (0 allocations: 0 bytes)
autodiff: 17.435 ns (0 allocations: 0 bytes)
f = f3, dim = 2
analytical: 5.000 ns (0 allocations: 0 bytes)
autodiff: 11.400 ns (0 allocations: 0 bytes)
f = f1, dim = 3
analytical: 101.067 ns (0 allocations: 0 bytes)
autodiff: 5.600 ns (0 allocations: 0 bytes)
f = f2, dim = 3
analytical: 132.686 ns (0 allocations: 0 bytes)
autodiff: 81.186 ns (0 allocations: 0 bytes)
f = f3, dim = 3
analytical: 29.335 ns (0 allocations: 0 bytes)
autodiff: 34.211 ns (0 allocations: 0 bytes)
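As the edit note above says, the naming in these test cases is outdated: the merged version exposes the functionality through the @implement_gradient macro instead of overloading on the Dual element type by hand. As a rough sketch (the name g2_dg2dx is made up here, and the snippet assumes the macro's documented convention that the derivative function returns both the value and the gradient), the f2/g2 case above could be written as:

```julia
using Tensors

g2(x) = x ⋅ x

# Analytical value and gradient of g2: D_ijkl = δ_ik x_lj + x_ik δ_jl
function g2_dg2dx(x::Tensor{2,dim}) where {dim}
    I2 = one(Tensor{2,dim})
    return g2(x), otimesu(I2, transpose(x)) + otimesu(x, I2)
end

# Register the analytical derivative; gradient(g2, x) now calls g2_dg2dx
# instead of propagating dual numbers through g2.
@implement_gradient(g2, g2_dg2dx)

x = rand(Tensor{2,3})
gradient(g2, x)  # fourth-order Tensor{4,3}
```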
Codecov Report
@@            Coverage Diff             @@
##           master    #176      +/-   ##
==========================================
+ Coverage   97.91%   97.98%   +0.06%
==========================================
  Files          16       16
  Lines        1296     1341      +45
==========================================
+ Hits         1269     1314      +45
  Misses         27       27
Would be good with a few tests but otherwise this looks ok. It is a bit awkward to special case the function on
Cool, I'll try to add a few tests tomorrow!
I'm not sure if it is ok that I included support for the additional operations for the open product, but it allowed me to write fewer functions and I think it makes sense?
# Scalar output -> Scalar value
"""
function _extract_value(v::ForwardDiff.Dual)
If this (and _insert_gradient) are documented, perhaps remove the underscore?
They are documented, but I didn't include them in the docs.
(As I documented them before using your @implement_gradient idea).
Should I then remove the underscore, include them in the docs, and export them, or just remove the underscore?
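(For context on the function discussed here: in the scalar case, extracting the value just means unwrapping the primal number stored in the Dual. A rough sketch, not the actual implementation in this PR:)

```julia
using ForwardDiff

# Sketch only: recover the primal (non-dual) value of a scalar
extract_value_sketch(v::ForwardDiff.Dual) = ForwardDiff.value(v)
extract_value_sketch(v) = v  # plain numbers pass through unchanged

d = ForwardDiff.Dual(1.5, 1.0)  # primal value 1.5 with one partial
extract_value_sketch(d) == 1.5  # true
```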
Let's go with this for now :)
This PR creates a mechanism for inserting a known derivative into a function that is part of a chain of functions differentiated with Tensors.jl's gradient method. The functionality was originally exposed as the functions insert_gradient and extract_value (that naming is now outdated); following the suggestion from @KristofferC below, the PR instead introduces the macro @implement_gradient. The main goal is to be able to provide analytical derivatives when automatic differentiation fails, see e.g. #174.

As a side benefit, if a known analytical derivative exists for part of a function that is slow to evaluate with dual numbers, this mechanism can be used to provide a more efficient implementation (though that is not the case for most functions in Tensors.jl).

Updated example (the cases in a comment below are outdated, see also the docs):
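A minimal usage sketch (illustrative only: my_norm and my_norm_dfdx are made-up names, and the snippet assumes the documented convention that the derivative function returns both the value and the gradient):

```julia
using Tensors

# A function whose derivative we know analytically: d(norm(x))/dx = x / norm(x)
my_norm(x::Tensor{2,dim}) where {dim} = norm(x)
my_norm_dfdx(x::Tensor{2,dim}) where {dim} = (norm(x), x / norm(x))

# Register the analytical derivative for use during automatic differentiation
@implement_gradient(my_norm, my_norm_dfdx)

# my_norm can now appear anywhere in a chain differentiated with Tensors.gradient;
# for this step the registered derivative is inserted instead of dual numbers.
ψ(x) = my_norm(x)^3
x = rand(Tensor{2,3})
gradient(ψ, x) ≈ 3 * norm(x) * x  # dψ/dx = 3‖x‖² ⋅ x/‖x‖ = 3‖x‖x
```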
Suggestions to improve the method are very welcome!
In order to work for #174 / #175 specifically, it would require support for 3rd order tensors if we don't want an ad-hoc fix for that particular differentiation case (because an (eigen)vector is differentiated wrt. a 2nd order tensor at some point in the differentiation chain); supporting 3rd order tensors should be a separate PR. #174 can now be solved with this PR if the user manually provides the derivative for their specific function. This PR might also be used to solve AD of Tensors.jl's eigenvector calculation, but that requires dealing with 3rd order tensors and is not considered here.

Help wanted

- Whether these functions (_extract_value and _insert_gradient) should be API or remain internal: @fredrikekre suggested making them public.
- Potential efficiency gains? (in some cases much slower execution time, but also sometimes faster): OK for now

Remaining tasks