@@ -73,10 +73,10 @@ mod llvm_enzyme {
7373 }
7474
7575 // Get information about the function the macro is applied to
76- fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident ) > {
76+ fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident , Generics ) > {
7777 match & iitem. kind {
78- ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
79- Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
78+ ItemKind :: Fn ( box ast:: Fn { sig, ident, generics , .. } ) => {
79+ Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics . clone ( ) ) )
8080 }
8181 _ => None ,
8282 }
@@ -210,16 +210,18 @@ mod llvm_enzyme {
210210 }
211211 let dcx = ecx. sess . dcx ( ) ;
212212
213- // first get information about the annotable item:
214- let Some ( ( vis, sig, primal) ) = ( match & item {
213+ // first get information about the annotable item: visibility, signature, name and generic
214+ // parameters.
215+ // these will be used to generate the differentiated version of the function
216+ let Some ( ( vis, sig, primal, generics) ) = ( match & item {
215217 Annotatable :: Item ( iitem) => extract_item_info ( iitem) ,
216218 Annotatable :: Stmt ( stmt) => match & stmt. kind {
217219 ast:: StmtKind :: Item ( iitem) => extract_item_info ( iitem) ,
218220 _ => None ,
219221 } ,
220222 Annotatable :: AssocItem ( assoc_item, Impl { .. } ) => match & assoc_item. kind {
221- ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
222- Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
223+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, generics , .. } ) => {
224+ Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics . clone ( ) ) )
223225 }
224226 _ => None ,
225227 } ,
@@ -303,14 +305,15 @@ mod llvm_enzyme {
303305 let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
304306 let d_body = gen_enzyme_body (
305307 ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
308+ & generics,
306309 ) ;
307310
308311 // The first element of it is the name of the function to be generated
309312 let asdf = Box :: new ( ast:: Fn {
310313 defaultness : ast:: Defaultness :: Final ,
311314 sig : d_sig,
312315 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
313- generics : Generics :: default ( ) ,
316+ generics,
314317 contract : None ,
315318 body : Some ( d_body) ,
316319 define_opaque : None ,
@@ -475,6 +478,7 @@ mod llvm_enzyme {
475478 new_decl_span : Span ,
476479 idents : & [ Ident ] ,
477480 errored : bool ,
481+ generics : & Generics ,
478482 ) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
479483 let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
480484 let noop = ast:: InlineAsm {
@@ -497,7 +501,7 @@ mod llvm_enzyme {
497501 } ;
498502 let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
499503 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
500- let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
504+ let primal_call = gen_primal_call ( ecx, span, primal, idents, generics ) ;
501505 let black_box_primal_call = ecx. expr_call (
502506 new_decl_span,
503507 blackbox_call_expr. clone ( ) ,
@@ -546,6 +550,7 @@ mod llvm_enzyme {
546550 sig_span : Span ,
547551 idents : Vec < Ident > ,
548552 errored : bool ,
553+ generics : & Generics ,
549554 ) -> P < ast:: Block > {
550555 let new_decl_span = d_sig. span ;
551556
@@ -566,6 +571,7 @@ mod llvm_enzyme {
566571 new_decl_span,
567572 & idents,
568573 errored,
574+ generics,
569575 ) ;
570576
571577 if !has_ret ( & d_sig. decl . output ) {
@@ -608,7 +614,6 @@ mod llvm_enzyme {
608614 panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
609615 }
610616 } ;
611-
612617 if x. mode . is_fwd ( ) {
613618 // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
614619 // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -668,8 +673,10 @@ mod llvm_enzyme {
668673 span : Span ,
669674 primal : Ident ,
670675 idents : & [ Ident ] ,
676+ generics : & Generics ,
671677 ) -> P < ast:: Expr > {
672678 let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
679+
673680 if has_self {
674681 let args: ThinVec < _ > =
675682 idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
@@ -678,7 +685,51 @@ mod llvm_enzyme {
678685 } else {
679686 let args: ThinVec < _ > =
680687 idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
681- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
688+ let mut primal_path = ecx. path_ident ( span, primal) ;
689+
690+ let is_generic = !generics. params . is_empty ( ) ;
691+
692+ match ( is_generic, primal_path. segments . last_mut ( ) ) {
693+ ( true , Some ( function_path) ) => {
694+ let primal_generic_types = generics
695+ . params
696+ . iter ( )
697+ . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
698+
699+ let generated_generic_types = primal_generic_types
700+ . map ( |type_param| {
701+ let generic_param = TyKind :: Path (
702+ None ,
703+ ast:: Path {
704+ span,
705+ segments : thin_vec ! [ ast:: PathSegment {
706+ ident: type_param. ident,
707+ args: None ,
708+ id: ast:: DUMMY_NODE_ID ,
709+ } ] ,
710+ tokens : None ,
711+ } ,
712+ ) ;
713+
714+ ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
715+ id : type_param. id ,
716+ span,
717+ kind : generic_param,
718+ tokens : None ,
719+ } ) ) )
720+ } )
721+ . collect ( ) ;
722+
723+ function_path. args =
724+ Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
725+ span,
726+ args : generated_generic_types,
727+ } ) ) ) ;
728+ }
729+ _ => { }
730+ }
731+
732+ let primal_call_expr = ecx. expr_path ( primal_path) ;
682733 ecx. expr_call ( span, primal_call_expr, args)
683734 }
684735 }
0 commit comments