diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 264f797a78250..2021acf4aca65 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -214,7 +214,7 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { + let Some((vis, sig, primal, generics, is_impl)) = (match &item { Annotatable::Item(iitem) => { extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) } @@ -224,13 +224,13 @@ mod llvm_enzyme { } _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + Annotatable::AssocItem(assoc_item, Impl { of_trait: _ }) => match &assoc_item.kind { ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone(), - *of_trait, + true, )), _ => None, }, @@ -328,7 +328,7 @@ mod llvm_enzyme { span, &d_sig, &generics, - impl_of_trait, + is_impl, )], ); diff --git a/tests/codegen-llvm/autodiff/impl.rs b/tests/codegen-llvm/autodiff/impl.rs new file mode 100644 index 0000000000000..185ea6af52e0f --- /dev/null +++ b/tests/codegen-llvm/autodiff/impl.rs @@ -0,0 +1,36 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[derive(Clone)] +struct OptProblem { + a: f64, + b: f64, +} + +impl OptProblem { + #[autodiff_reverse(d_objective, Duplicated, Duplicated, Duplicated)] + fn objective(&self, x: &[f64], out: &mut f64) { + *out = self.a + x[0].sqrt() * self.b + } +} + +fn main() { + let p = OptProblem { a: 1., b: 2. }; + let x = [2.0]; + + let mut p_shadow = OptProblem { a: 0., b: 0. }; + let mut dx = [0.0]; + let mut out = 0.0; + let mut dout = 1.0; + + p.d_objective(&mut p_shadow, &x, &mut dx, &mut out, &mut dout); + + dbg!(dx); +}