Skip to content

Commit 3732c3c

Browse files
authored
Rollup merge of #148201 - ZuseZ4:autodiff-activity-docs, r=oli-obk
Start documenting autodiff activities Some initial documentation of the autodiff macros and usage examples
2 parents 714f1ce + f5892da commit 3732c3c

File tree

5 files changed

+112
-0
lines changed

5 files changed

+112
-0
lines changed

library/core/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ optimize_for_size = []
2323
# Make `RefCell` store additional debugging information, which is printed out when
2424
# a borrow error occurs
2525
debug_refcell = []
26+
llvm_enzyme = []
2627

2728
[lints.rust.unexpected_cfgs]
2829
level = "warn"
@@ -38,4 +39,6 @@ check-cfg = [
3839
'cfg(target_has_reliable_f16_math)',
3940
'cfg(target_has_reliable_f128)',
4041
'cfg(target_has_reliable_f128_math)',
42+
'cfg(llvm_enzyme)',
43+
4144
]

library/core/src/macros/mod.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,55 @@ pub(crate) mod builtin {
14991499
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
15001500
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
15011501
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
1502+
///
1503+
/// ACTIVITIES might either be `Dual` or `Const`, more options will be exposed later.
1504+
///
1505+
/// `Const` should be used on non-float arguments, or float-based arguments as an optimization
1506+
/// if we are not interested in computing the derivatives with respect to this argument.
1507+
///
1508+
/// `Dual` can be used for float scalar values or for references, raw pointers, or other
1509+
/// indirect input arguments. It can also be used on a scalar float return value.
1510+
/// If used on a return value, the generated function will return a tuple of two float scalars.
1511+
/// If used on an input argument, a new shadow argument of the same type will be created,
1512+
/// directly following the original argument.
1513+
///
1514+
/// ### Usage examples:
1515+
///
1516+
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
1517+
/// #![feature(autodiff)]
1518+
/// use std::autodiff::*;
1519+
/// #[autodiff_forward(rb_fwd1, Dual, Const, Dual)]
1520+
/// #[autodiff_forward(rb_fwd2, Const, Dual, Dual)]
1521+
/// #[autodiff_forward(rb_fwd3, Dual, Dual, Dual)]
1522+
/// fn rosenbrock(x: f64, y: f64) -> f64 {
1523+
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
1524+
/// }
1525+
/// #[autodiff_forward(rb_inp_fwd, Dual, Dual, Dual)]
1526+
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
1527+
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
1528+
/// }
1529+
///
1530+
/// fn main() {
1531+
/// let x0 = rosenbrock(1.0, 3.0); // 400.0
1532+
/// let (x1, dx1) = rb_fwd1(1.0, 1.0, 3.0); // (400.0, -800.0)
1533+
/// let (x2, dy1) = rb_fwd2(1.0, 3.0, 1.0); // (400.0, 400.0)
1534+
/// // When seeding both arguments at once the tangent return is the sum of both.
1535+
/// let (x3, dxy) = rb_fwd3(1.0, 1.0, 3.0, 1.0); // (400.0, -400.0)
1536+
///
1537+
/// let mut out = 0.0;
1538+
/// let mut dout = 0.0;
1539+
/// rb_inp_fwd(1.0, 1.0, 3.0, 1.0, &mut out, &mut dout);
1540+
/// // (out, dout) == (400.0, -400.0)
1541+
/// }
1542+
/// ```
1543+
///
1544+
/// We might want to track how one input float affects one or more output floats. In this case,
1545+
/// the shadow of one input should be initialized to `1.0`, while the shadows of the other
1546+
/// inputs should be initialized to `0.0`. The shadow of the output(s) should be initialized to
1547+
/// `0.0`. After calling the generated function, the shadow of the input will be zeroed,
1548+
/// while the shadow(s) of the output(s) will contain the derivatives. Forward mode is generally
1549+
/// more efficient if we have more output floats marked as `Dual` than input floats.
1550+
/// Related information can also be found under the term "Vector-Jacobian product" (VJP).
15021551
#[unstable(feature = "autodiff", issue = "124509")]
15031552
#[allow_internal_unstable(rustc_attrs)]
15041553
#[allow_internal_unstable(core_intrinsics)]
@@ -1518,6 +1567,60 @@ pub(crate) mod builtin {
15181567
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
15191568
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
15201569
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
1570+
///
1571+
/// ACTIVITIES might either be `Active`, `Duplicated` or `Const`, more options will be exposed later.
1572+
///
1573+
/// `Active` can be used for float scalar values.
1574+
/// If used on an input, a new float will be appended to the return tuple of the generated
1575+
/// function. If the function returns a float scalar, `Active` can be used for the return as
1576+
/// well. In this case a float scalar will be appended to the argument list, it works as seed.
1577+
///
1578+
/// `Duplicated` can be used on references, raw pointers, or other indirect input
1579+
/// arguments. It creates a new shadow argument of the same type, following the original argument.
1580+
/// A const reference or pointer argument will receive a mutable reference or pointer as shadow.
1581+
///
1582+
/// `Const` should be used on non-float arguments, or float-based arguments as an optimization
1583+
/// if we are not interested in computing the derivatives with respect to this argument.
1584+
///
1585+
/// ### Usage examples:
1586+
///
1587+
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
1588+
/// #![feature(autodiff)]
1589+
/// use std::autodiff::*;
1590+
/// #[autodiff_reverse(rb_rev, Active, Active, Active)]
1591+
/// fn rosenbrock(x: f64, y: f64) -> f64 {
1592+
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
1593+
/// }
1594+
/// #[autodiff_reverse(rb_inp_rev, Active, Active, Duplicated)]
1595+
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
1596+
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
1597+
/// }
1598+
///
1599+
/// fn main() {
1600+
/// let (output1, dx1, dy1) = rb_rev(1.0, 3.0, 1.0);
1601+
/// dbg!(output1, dx1, dy1); // (400.0, -800.0, 400.0)
1602+
/// let mut output2 = 0.0;
1603+
/// let mut seed = 1.0;
1604+
/// let (dx2, dy2) = rb_inp_rev(1.0, 3.0, &mut output2, &mut seed);
1605+
/// // (dx2, dy2, output2, seed) == (-800.0, 400.0, 400.0, 0.0)
1606+
/// }
1607+
/// ```
1608+
///
1609+
///
1610+
/// We often want to track how one or more input floats affect one output float. This output can
1611+
/// be a scalar return value, or a mutable reference or pointer argument. In the latter case, the
1612+
/// mutable input should be marked as duplicated and its shadow initialized to `0.0`. The shadow of
1613+
/// the output should be marked as active or duplicated and initialized to `1.0`. After calling
1614+
/// the generated function, the shadow(s) of the input(s) will contain the derivatives. The
1615+
/// shadow of the outputs ("seed") will be reset to zero.
1616+
/// If the function has more than one output float marked as active or duplicated, users might want to
1617+
/// set one of them to `1.0` and the others to `0.0` to compute partial derivatives.
1618+
/// Unlike forward-mode, a call to the generated function does not reset the shadow of the
1619+
/// inputs.
1620+
/// Reverse mode is generally more efficient if we have more active/duplicated input than
1621+
/// output floats.
1622+
///
1623+
/// Related information can also be found under the term "Jacobian-Vector Product" (JVP).
15211624
#[unstable(feature = "autodiff", issue = "124509")]
15221625
#[allow_internal_unstable(rustc_attrs)]
15231626
#[allow_internal_unstable(core_intrinsics)]

library/std/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ optimize_for_size = ["core/optimize_for_size", "alloc/optimize_for_size"]
126126
# a borrow error occurs
127127
debug_refcell = ["core/debug_refcell"]
128128

129+
llvm_enzyme = ["core/llvm_enzyme"]
129130

130131
# Enable std_detect features:
131132
std_detect_file_io = ["std_detect/std_detect_file_io"]

library/sysroot/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ profiler = ["dep:profiler_builtins"]
3535
std_detect_file_io = ["std/std_detect_file_io"]
3636
std_detect_dlsym_getauxval = ["std/std_detect_dlsym_getauxval"]
3737
windows_raw_dylib = ["std/windows_raw_dylib"]
38+
llvm_enzyme = ["std/llvm_enzyme"]

src/bootstrap/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,10 @@ impl Build {
846846
features.insert("compiler-builtins-mem");
847847
}
848848

849+
if self.config.llvm_enzyme {
850+
features.insert("llvm_enzyme");
851+
}
852+
849853
features.into_iter().collect::<Vec<_>>().join(" ")
850854
}
851855

0 commit comments

Comments
 (0)