@@ -17,7 +17,7 @@ mod llvm_enzyme {
1717    use  rustc_ast:: visit:: AssocCtxt :: * ; 
1818    use  rustc_ast:: { 
1919        self  as  ast,  AssocItemKind ,  BindingMode ,  ExprKind ,  FnRetTy ,  FnSig ,  Generics ,  ItemKind , 
20-         MetaItemInner ,  PatKind ,  QSelf ,  TyKind , 
20+         MetaItemInner ,  PatKind ,  QSelf ,  TyKind ,   Visibility , 
2121    } ; 
2222    use  rustc_expand:: base:: { Annotatable ,  ExtCtxt } ; 
2323    use  rustc_span:: { Ident ,  Span ,  Symbol ,  kw,  sym} ; 
@@ -72,6 +72,16 @@ mod llvm_enzyme {
7272        } 
7373    } 
7474
75+     // Get information about the function the macro is applied to 
76+     fn  extract_item_info ( iitem :  & P < ast:: Item > )  -> Option < ( Visibility ,  FnSig ,  Ident ) >  { 
77+         match  & iitem. kind  { 
78+             ItemKind :: Fn ( box ast:: Fn  {  sig,  ident,  .. } )  => { 
79+                 Some ( ( iitem. vis . clone ( ) ,  sig. clone ( ) ,  ident. clone ( ) ) ) 
80+             } 
81+             _ => None , 
82+         } 
83+     } 
84+ 
7585    pub ( crate )  fn  from_ast ( 
7686        ecx :  & mut  ExtCtxt < ' _ > , 
7787        meta_item :  & ThinVec < MetaItemInner > , 
@@ -199,32 +209,26 @@ mod llvm_enzyme {
199209            return  vec ! [ item] ; 
200210        } 
201211        let  dcx = ecx. sess . dcx ( ) ; 
202-         // first get the annotable item: 
203-         let  ( primal,  sig,  is_impl) :  ( Ident ,  FnSig ,  bool )  = match  & item { 
204-             Annotatable :: Item ( iitem)  => { 
205-                 let  ( ident,  sig)  = match  & iitem. kind  { 
206-                     ItemKind :: Fn ( box ast:: Fn  {  ident,  sig,  .. } )  => ( ident,  sig) , 
207-                     _ => { 
208-                         dcx. emit_err ( errors:: AutoDiffInvalidApplication  {  span :  item. span ( )  } ) ; 
209-                         return  vec ! [ item] ; 
210-                     } 
211-                 } ; 
212-                 ( * ident,  sig. clone ( ) ,  false ) 
213-             } 
212+ 
213+         // first get information about the annotable item: 
214+         let  Some ( ( vis,  sig,  primal) )  = ( match  & item { 
215+             Annotatable :: Item ( iitem)  => extract_item_info ( iitem) , 
216+             Annotatable :: Stmt ( stmt)  => match  & stmt. kind  { 
217+                 ast:: StmtKind :: Item ( iitem)  => extract_item_info ( iitem) , 
218+                 _ => None , 
219+             } , 
214220            Annotatable :: AssocItem ( assoc_item,  Impl  {  of_trait :  false  } )  => { 
215-                 let  ( ident,  sig)  = match  & assoc_item. kind  { 
216-                     ast:: AssocItemKind :: Fn ( box ast:: Fn  {  ident,  sig,  .. } )  => ( ident,  sig) , 
217-                     _ => { 
218-                         dcx. emit_err ( errors:: AutoDiffInvalidApplication  {  span :  item. span ( )  } ) ; 
219-                         return  vec ! [ item] ; 
221+                 match  & assoc_item. kind  { 
222+                     ast:: AssocItemKind :: Fn ( box ast:: Fn  {  sig,  ident,  .. } )  => { 
223+                         Some ( ( assoc_item. vis . clone ( ) ,  sig. clone ( ) ,  ident. clone ( ) ) ) 
220224                    } 
221-                 } ; 
222-                 ( * ident,  sig. clone ( ) ,  true ) 
223-             } 
224-             _ => { 
225-                 dcx. emit_err ( errors:: AutoDiffInvalidApplication  {  span :  item. span ( )  } ) ; 
226-                 return  vec ! [ item] ; 
225+                     _ => None , 
226+                 } 
227227            } 
228+             _ => None , 
229+         } )  else  { 
230+             dcx. emit_err ( errors:: AutoDiffInvalidApplication  {  span :  item. span ( )  } ) ; 
231+             return  vec ! [ item] ; 
228232        } ; 
229233
230234        let  meta_item_vec:  ThinVec < MetaItemInner >  = match  meta_item. kind  { 
@@ -238,15 +242,6 @@ mod llvm_enzyme {
238242        let  has_ret = has_ret ( & sig. decl . output ) ; 
239243        let  sig_span = ecx. with_call_site_ctxt ( sig. span ) ; 
240244
241-         let  vis = match  & item { 
242-             Annotatable :: Item ( iitem)  => iitem. vis . clone ( ) , 
243-             Annotatable :: AssocItem ( assoc_item,  _)  => assoc_item. vis . clone ( ) , 
244-             _ => { 
245-                 dcx. emit_err ( errors:: AutoDiffInvalidApplication  {  span :  item. span ( )  } ) ; 
246-                 return  vec ! [ item] ; 
247-             } 
248-         } ; 
249- 
250245        // create TokenStream from vec elemtents: 
251246        // meta_item doesn't have a .tokens field 
252247        let  mut  ts:  Vec < TokenTree >  = vec ! [ ] ; 
@@ -379,6 +374,22 @@ mod llvm_enzyme {
379374                } 
380375                Annotatable :: AssocItem ( assoc_item. clone ( ) ,  i) 
381376            } 
377+             Annotatable :: Stmt ( ref  mut  stmt)  => { 
378+                 match  stmt. kind  { 
379+                     ast:: StmtKind :: Item ( ref  mut  iitem)  => { 
380+                         if  !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind ,  & attr. kind ) )  { 
381+                             iitem. attrs . push ( attr) ; 
382+                         } 
383+                         if  !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind ,  & inline_never. kind ) ) 
384+                         { 
385+                             iitem. attrs . push ( inline_never. clone ( ) ) ; 
386+                         } 
387+                     } 
388+                     _ => unreachable ! ( "stmt kind checked previously" ) , 
389+                 } ; 
390+ 
391+                 Annotatable :: Stmt ( stmt. clone ( ) ) 
392+             } 
382393            _ => { 
383394                unreachable ! ( "annotatable kind checked previously" ) 
384395            } 
@@ -389,22 +400,40 @@ mod llvm_enzyme {
389400            delim :  rustc_ast:: token:: Delimiter :: Parenthesis , 
390401            tokens :  ts, 
391402        } ) ; 
403+ 
392404        let  d_attr = outer_normal_attr ( & rustc_ad_attr,  new_id,  span) ; 
393-         let  d_annotatable = if  is_impl { 
394-             let  assoc_item:  AssocItemKind  = ast:: AssocItemKind :: Fn ( asdf) ; 
395-             let  d_fn = P ( ast:: AssocItem  { 
396-                 attrs :  thin_vec ! [ d_attr,  inline_never] , 
397-                 id :  ast:: DUMMY_NODE_ID , 
398-                 span, 
399-                 vis, 
400-                 kind :  assoc_item, 
401-                 tokens :  None , 
402-             } ) ; 
403-             Annotatable :: AssocItem ( d_fn,  Impl  {  of_trait :  false  } ) 
404-         }  else  { 
405-             let  mut  d_fn = ecx. item ( span,  thin_vec ! [ d_attr,  inline_never] ,  ItemKind :: Fn ( asdf) ) ; 
406-             d_fn. vis  = vis; 
407-             Annotatable :: Item ( d_fn) 
405+         let  d_annotatable = match  & item { 
406+             Annotatable :: AssocItem ( _,  _)  => { 
407+                 let  assoc_item:  AssocItemKind  = ast:: AssocItemKind :: Fn ( asdf) ; 
408+                 let  d_fn = P ( ast:: AssocItem  { 
409+                     attrs :  thin_vec ! [ d_attr,  inline_never] , 
410+                     id :  ast:: DUMMY_NODE_ID , 
411+                     span, 
412+                     vis, 
413+                     kind :  assoc_item, 
414+                     tokens :  None , 
415+                 } ) ; 
416+                 Annotatable :: AssocItem ( d_fn,  Impl  {  of_trait :  false  } ) 
417+             } 
418+             Annotatable :: Item ( _)  => { 
419+                 let  mut  d_fn = ecx. item ( span,  thin_vec ! [ d_attr,  inline_never] ,  ItemKind :: Fn ( asdf) ) ; 
420+                 d_fn. vis  = vis; 
421+ 
422+                 Annotatable :: Item ( d_fn) 
423+             } 
424+             Annotatable :: Stmt ( _)  => { 
425+                 let  mut  d_fn = ecx. item ( span,  thin_vec ! [ d_attr,  inline_never] ,  ItemKind :: Fn ( asdf) ) ; 
426+                 d_fn. vis  = vis; 
427+ 
428+                 Annotatable :: Stmt ( P ( ast:: Stmt  { 
429+                     id :  ast:: DUMMY_NODE_ID , 
430+                     kind :  ast:: StmtKind :: Item ( d_fn) , 
431+                     span, 
432+                 } ) ) 
433+             } 
434+             _ => { 
435+                 unreachable ! ( "item kind checked previously" ) 
436+             } 
408437        } ; 
409438
410439        return  vec ! [ orig_annotatable,  d_annotatable] ; 
0 commit comments