
Enable insertion of known derivative in autodiff chain #176

Merged
KristofferC merged 15 commits into Ferrite-FEM:master from KnutAM:kam/anadiff_in_autodiff on Jan 26, 2022

Conversation

@KnutAM (Member) commented on Jan 14, 2022

This PR adds a mechanism for inserting a known derivative into a function that is part of a chain of functions differentiated with Tensors.jl's gradient method. The new method is insert_gradient, and extract_value has also been exported. Using the suggestion from @KristofferC below, it introduces the macro @implement_gradient.

The main goal is to be able to provide analytical derivatives when automatic differentiation fails, see e.g. #174.
As a side benefit, if an analytical derivative is known for a part of a function that is slow to evaluate with dual numbers, this mechanism can be used to provide a more efficient implementation (though this is not the case for most functions in Tensors.jl).

Updated example (cases in a comment below are outdated, see also docs):

using Tensors
# The function to be given a known gradient
f(x) = x ⋅ x
# Define a function that calculates both the value and the analytical derivative
function f_dfdx(x::Tensor{2,dim}) where {dim}
    fval = f(x)
    I2 = one(Tensor{2,dim})
    df = otimesu(I2, transpose(x)) + otimesu(x, I2)
    return fval, df
end
# Connect the analytical derivative for `f` to `f_dfdx`
@implement_gradient f f_dfdx
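
With this in place, calling gradient on f uses the supplied analytical derivative instead of dual numbers. A minimal usage sketch (assuming a 3-dimensional second-order tensor input):

x = rand(Tensor{2,3})
dfdx = gradient(f, x)  # dispatches to f_dfdx instead of differentiating with dual numbers
dfdx ≈ otimesu(one(Tensor{2,3}), transpose(x)) + otimesu(x, one(Tensor{2,3}))  # true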

Suggestions to improve the method are very welcome!

In order to work for #174 / #175 specifically, it seems to require support for 3rd order tensors unless we accept an ad-hoc fix for that specific differentiation case (because an (eigen)vector is differentiated w.r.t. a 2nd order tensor at some point in the differentiation chain). I guess supporting 3rd order tensors should be a separate PR though.

#174 can now be solved with this PR if the user manually provides the derivative for their specific function. This PR could also be used to solve AD of Tensors.jl's eigenvector calculation, but that requires dealing with 3rd order tensors and is not considered here.

Help wanted

  • Decide if functionality (_insert_gradient) should be API or remain internal: @fredrikekre suggested to make public.
  • Potential efficiency gains? (in some cases much slower execution time, but also sometimes faster): OK for now

Remaining tasks

  • Support symmetric tensors
  • Add tests
  • Add documentation

@KnutAM (Member, Author) commented on Jan 14, 2022

Edit: naming outdated

Test cases

import Tensors: Dual, _insert_gradient, _extract_value
using Tensors, ForwardDiff
using BenchmarkTools
using Test

if !isdefined(Main, :DualTensor)
    const DualTensor=Dual{<:ForwardDiff.Tag{<:Any,<:AbstractTensor}}
end

g1(x) = x
f1(x) = g1(x)

g2(x) = x ⋅ x
f2(x) = g2(x)

g3(x) = x ⊡ x
f3(x) = g3(x)

function f1(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    fval = f1(_extract_value(x))
    ∇f = one(Tensor{4,dim})
    return _insert_gradient(fval, ∇f, x)
end

function f2(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    xval = _extract_value(x)
    fval = f2(xval)
    I2 = one(Tensor{2,dim})
    ∇f = otimesu(I2, transpose(xval)) + otimesu(xval, I2)
    return _insert_gradient(fval, ∇f, x)
end

function f3(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    xval = _extract_value(x)
    fval = f3(xval)
    ∇f = 2*xval
    return _insert_gradient(fval, ∇f, x)
end

for dim in 1:3
    a = rand(Tensor{2,dim})
    for (f,g) in ((f1,g1),(f2,g2),(f3,g3))
        println("f = $f, dim = $dim")
        @test gradient(f,a)==gradient(g,a)
        print("analytical: "); @btime gradient($f, $a); 
        print("autodiff:   "); @btime gradient($g, $a);
        println("")
    end
end

Results

f = f1, dim = 1
analytical:   0.800 ns (0 allocations: 0 bytes)
autodiff:     0.800 ns (0 allocations: 0 bytes)

f = f2, dim = 1
analytical:   6.900 ns (0 allocations: 0 bytes)
autodiff:     0.800 ns (0 allocations: 0 bytes)

f = f3, dim = 1
analytical:   0.800 ns (0 allocations: 0 bytes)
autodiff:     0.800 ns (0 allocations: 0 bytes)

f = f1, dim = 2
analytical:   13.126 ns (0 allocations: 0 bytes)
autodiff:     1.000 ns (0 allocations: 0 bytes)

f = f2, dim = 2
analytical:   23.571 ns (0 allocations: 0 bytes)
autodiff:     17.435 ns (0 allocations: 0 bytes)

f = f3, dim = 2
analytical:   5.000 ns (0 allocations: 0 bytes)
autodiff:     11.400 ns (0 allocations: 0 bytes)

f = f1, dim = 3
analytical:   101.067 ns (0 allocations: 0 bytes)
autodiff:     5.600 ns (0 allocations: 0 bytes)

f = f2, dim = 3
analytical:   132.686 ns (0 allocations: 0 bytes)
autodiff:     81.186 ns (0 allocations: 0 bytes)

f = f3, dim = 3
analytical:   29.335 ns (0 allocations: 0 bytes)
autodiff:     34.211 ns (0 allocations: 0 bytes)

@codecov-commenter commented on Jan 14, 2022

Codecov Report

Merging #176 (44cc54e) into master (c3b72b3) will increase coverage by 0.06%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #176      +/-   ##
==========================================
+ Coverage   97.91%   97.98%   +0.06%     
==========================================
  Files          16       16              
  Lines        1296     1341      +45     
==========================================
+ Hits         1269     1314      +45     
  Misses         27       27              
Impacted Files                      Coverage Δ
src/Tensors.jl                      81.81% <ø> (ø)
src/automatic_differentiation.jl    99.03% <100.00%> (+0.24%) ⬆️
src/tensor_products.jl              100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@KristofferC (Collaborator) commented

It would be good to have a few tests, but otherwise this looks OK. It is a bit awkward to special-case the function on Dual. A possible API could be similar to what I did in JuliaDiff/ForwardDiff.jl#165, but this is OK as well.

@KnutAM (Member, Author) commented on Jan 23, 2022

Cool, I'll try to add a few tests tomorrow!
Do you mean providing a macro like @implement_gradient? I think that looks nice!
As a novice Julia programmer, however, I'm a bit concerned that users (like me) may find debugging harder when macros are involved, but perhaps that concern is unwarranted?
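
For illustration, here is a minimal sketch of what a call like @implement_gradient f f_dfdx could expand to, mirroring the manual dispatch pattern from the test cases above; the actual code generated by the macro in this PR may differ, and the DualTensor alias is only part of the sketch:

import Tensors: Dual, _insert_gradient, _extract_value
using Tensors, ForwardDiff

# Dual numbers tagged with a tensor, as in the test cases above
const DualTensor = Dual{<:ForwardDiff.Tag{<:Any,<:AbstractTensor}}

# Specialized method that intercepts dual-valued input and returns the known derivative
function f(x::Tensor{2,dim,T}) where {dim, T<:DualTensor}
    xval = _extract_value(x)                # plain values, without the dual parts
    fval, dfdx = f_dfdx(xval)               # user-supplied value and analytical derivative
    return _insert_gradient(fval, dfdx, x)  # put the known derivative back into the chain
end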

@KnutAM KnutAM changed the title WIP: Enable insertion of known derivative in autodiff chain Enable insertion of known derivative in autodiff chain Jan 25, 2022
@KnutAM KnutAM marked this pull request as ready for review January 25, 2022 20:31
@KnutAM (Member, Author) commented on Jan 25, 2022

I'm not sure if it is OK that I also allowed the additional operations for the open product, but it let me write fewer functions and I think it makes sense?

@KristofferC (Collaborator) left a review comment

LGTM, nice job!


Quoted code under review:

    # Scalar output -> Scalar value
    """
    function _extract_value(v::ForwardDiff.Dual)
@KristofferC (Collaborator):

If this (and _insert_gradient) are documented, perhaps remove the underscore?

@KnutAM (Member, Author):

They are documented, but I didn't include them in the docs (I documented them before using your @implement_gradient idea).
Should I then remove the underscore, include them in the docs, and export them, or just remove the underscore?

@KristofferC (Collaborator) commented

Let's go with this for now :)

@KristofferC KristofferC merged commit 8f67d27 into Ferrite-FEM:master Jan 26, 2022