Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions library/core/src/autodiff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
This module provides support for automatic differentiation. For precise information on
differences between the `autodiff_forward` and `autodiff_reverse` macros and how to
use them, see their respective documentation.

## General usage

Autodiff macros can be applied to almost all function definitions, see below for examples.
They can be applied to functions accepting structs, arrays, slices, vectors, tuples, and more.

It is possible to apply multiple autodiff macros to the same function. As an example, this can
be helpful to compute the partial derivatives with respect to `x` and `y` independently:
```rust,ignore (optional component)
#[autodiff_forward(dsquare1, Dual, Const, Dual)]
#[autodiff_forward(dsquare2, Const, Dual, Dual)]
#[autodiff_forward(dsquare3, Active, Active, Active)]
fn square(x: f64, y: f64) -> f64 {
x * x + 2.0 * y
}
```

We also support autodiff on functions with generic parameters:
```rust,ignore (optional component)
#[autodiff_forward(generic_derivative, Duplicated, Active)]
fn generic_f<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
x * x
}
```

or applying autodiff to nested functions:
```rust,ignore (optional component)
fn outer(x: f64) -> f64 {
#[autodiff_forward(inner_derivative, Dual, Const)]
fn inner(y: f64) -> f64 {
y * y
}
inner_derivative(x, 1.0)
}

fn main() {
assert_eq!(outer(3.14), 6.28);
}
```
The generated function will be available in the same scope as the function differentiated, and
have the same private/pub usability.

## Traits and impls
Autodiff macros can be used in multiple ways in combination with traits:
```rust,ignore (optional component)
struct Foo {
a: f64,
}

trait MyTrait {
#[autodiff_reverse(df, Const, Active, Active)]
fn f(&self, x: f64) -> f64;
}

impl MyTrait for Foo {
fn f(&self, x: f64) -> f64 {
x.sin()
}
}

fn main() {
let foo = Foo { a: 3.0f64 };
assert_eq!(foo.f(2.0), 2.0_f64.sin());
assert_eq!(foo.df(2.0, 1.0).1, 2.0_f64.cos());
}
```
In this case `df` will be the default implementation provided by the library who provided the
trait. A user implementing `MyTrait` could then decide to use the default implementation of
`df`, or overwrite it with a custom implementation as a form of "custom derivatives".

On the other hand, a function generated by either autodiff macro can also be used to implement a
trait:
```rust,ignore (optional component)
struct Foo {
a: f64,
}

trait MyTrait {
fn f(&self, x: f64) -> f64;
fn df(&self, x: f64, seed: f64) -> (f64, f64);
}

impl MyTrait for Foo {
#[autodiff_reverse(df, Const, Active, Active)]
fn f(&self, x: f64) -> f64 {
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
}
}
```

Simple `impl` blocks without traits are also supported. Differentiating with respect to the
implemented struct will then require the use of a "shadow struct" to hold the derivatives of the
struct fields:

```rust,ignore (optional component)
struct OptProblem {
a: f64,
b: f64,
}

impl OptProblem {
#[autodiff_reverse(d_objective, Duplicated, Duplicated, Duplicated)]
fn objective(&self, x: &[f64], out: &mut f64) {
*out = self.a + x[0].sqrt() * self.b
}
}
fn main() {
let p = OptProblem { a: 1., b: 2. };
let mut p_shadow = OptProblem { a: 0., b: 0. };
let mut dx = [0.0];
let mut out = 0.0;
let mut dout = 1.0;

p.d_objective(&mut p_shadow, &x, &mut dx, &mut out, &mut dout);
}
```

## Higher-order derivatives
Finally, it is possible to generate higher-order derivatives (e.g. Hessian) by applying an
autodiff macro to a function that is already generated by an autodiff macro, via a thin wrapper.
The following example uses Forward mode over Reverse mode

```rust,ignore (optional component)
#[autodiff_reverse(df, Duplicated, Duplicated)]
fn f(x: &[f64;2], y: &mut f64) {
*y = x[0] * x[0] + x[1] * x[0]
}

#[autodiff_forward(h, Dual, Dual, Dual, Dual)]
fn wrapper(x: &[f64;2], dx: &mut [f64;2], y: &mut f64, dy: &mut f64) {
df(x, dx, y, dy);
}

fn main() {
let mut y = 0.0;
let x = [2.0, 2.0];

let mut dy = 0.0;
let mut dx = [1.0, 0.0];

let mut bx = [0.0, 0.0];
let mut by = 1.0;
let mut dbx = [0.0, 0.0];
let mut dby = 0.0;
h(&x, &mut dx, &mut bx, &mut dbx, &mut y, &mut dy, &mut by, &mut dby);
assert_eq!(&dbx, [2.0, 1.0]);
}
```

## Current limitations:

- Differentiating a function which accepts a `dyn Trait` is currently not supported.
- Builds without `lto="fat"` are not yet supported.
- Builds in debug mode are currently more likely to fail compilation.
2 changes: 1 addition & 1 deletion library/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ pub mod from {

// We don't export this through #[macro_export] for now, to avoid breakage.
#[unstable(feature = "autodiff", issue = "124509")]
/// Unstable module containing the unstable `autodiff` macro.
#[doc = include_str!("../../core/src/autodiff.md")]
pub mod autodiff {
#[unstable(feature = "autodiff", issue = "124509")]
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
Expand Down
2 changes: 1 addition & 1 deletion library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ pub mod simd {
}

#[unstable(feature = "autodiff", issue = "124509")]
/// This module provides support for automatic differentiation.
#[doc = include_str!("../../core/src/autodiff.md")]
pub mod autodiff {
/// This macro handles automatic differentiation.
pub use core::autodiff::{autodiff_forward, autodiff_reverse};
Expand Down
Loading