@@ -6,6 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
66
77use crate :: utils:: Ctx ;
88use crate :: {
9+ attributes,
910 attributes:: { FromPyWithAttribute , TextSignatureAttribute , TextSignatureAttributeValue } ,
1011 deprecations:: { Deprecation , Deprecations } ,
1112 params:: { impl_arg_params, Holders } ,
@@ -379,6 +380,7 @@ pub struct FnSpec<'a> {
379380 pub asyncness : Option < syn:: Token ![ async ] > ,
380381 pub unsafety : Option < syn:: Token ![ unsafe ] > ,
381382 pub deprecations : Deprecations < ' a > ,
383+ pub allow_threads : Option < attributes:: kw:: allow_threads > ,
382384}
383385
384386pub fn parse_method_receiver ( arg : & syn:: FnArg ) -> Result < SelfType > {
@@ -416,6 +418,7 @@ impl<'a> FnSpec<'a> {
416418 text_signature,
417419 name,
418420 signature,
421+ allow_threads,
419422 ..
420423 } = options;
421424
@@ -461,6 +464,7 @@ impl<'a> FnSpec<'a> {
461464 asyncness : sig. asyncness ,
462465 unsafety : sig. unsafety ,
463466 deprecations,
467+ allow_threads,
464468 } )
465469 }
466470
@@ -603,6 +607,21 @@ impl<'a> FnSpec<'a> {
603607 bail_spanned ! ( name. span( ) => "`cancel_handle` may only be specified once" ) ;
604608 }
605609 }
610+ if let Some ( FnArg :: Py ( py_arg) ) = self
611+ . signature
612+ . arguments
613+ . iter ( )
614+ . find ( |arg| matches ! ( arg, FnArg :: Py ( _) ) )
615+ {
616+ ensure_spanned ! (
617+ self . asyncness. is_none( ) ,
618+ py_arg. ty. span( ) => "GIL token cannot be passed to async function"
619+ ) ;
620+ ensure_spanned ! (
621+ self . allow_threads. is_none( ) ,
622+ py_arg. ty. span( ) => "GIL cannot be held in function annotated with `allow_threads`"
623+ ) ;
624+ }
606625
607626 if self . asyncness . is_some ( ) {
608627 ensure_spanned ! (
@@ -612,8 +631,21 @@ impl<'a> FnSpec<'a> {
612631 }
613632
614633 let rust_call = |args : Vec < TokenStream > , holders : & mut Holders | {
615- let mut self_arg = || self . tp . self_arg ( cls, ExtractErrorMode :: Raise , holders, ctx) ;
616-
634+ let allow_threads = self . allow_threads . is_some ( ) ;
635+ let mut self_arg = || {
636+ let self_arg = self . tp . self_arg ( cls, ExtractErrorMode :: Raise , holders, ctx) ;
637+ if self_arg. is_empty ( ) {
638+ self_arg
639+ } else {
640+ let self_checker = holders. push_gil_refs_checker ( self_arg. span ( ) ) ;
641+ quote ! {
642+ #pyo3_path:: impl_:: deprecations:: inspect_type( #self_arg & #self_checker) ,
643+ }
644+ }
645+ } ;
646+ let arg_names = ( 0 ..args. len ( ) )
647+ . map ( |i| format_ident ! ( "arg_{}" , i) )
648+ . collect :: < Vec < _ > > ( ) ;
617649 let call = if self . asyncness . is_some ( ) {
618650 let throw_callback = if cancel_handle. is_some ( ) {
619651 quote ! { Some ( __throw_callback) }
@@ -625,9 +657,6 @@ impl<'a> FnSpec<'a> {
625657 Some ( cls) => quote ! ( Some ( <#cls as #pyo3_path:: PyTypeInfo >:: NAME ) ) ,
626658 None => quote ! ( None ) ,
627659 } ;
628- let arg_names = ( 0 ..args. len ( ) )
629- . map ( |i| format_ident ! ( "arg_{}" , i) )
630- . collect :: < Vec < _ > > ( ) ;
631660 let future = match self . tp {
632661 FnType :: Fn ( SelfType :: Receiver { mutable : false , .. } ) => {
633662 quote ! { {
@@ -645,18 +674,7 @@ impl<'a> FnSpec<'a> {
645674 }
646675 _ => {
647676 let self_arg = self_arg ( ) ;
648- if self_arg. is_empty ( ) {
649- quote ! { function( #( #args) , * ) }
650- } else {
651- let self_checker = holders. push_gil_refs_checker ( self_arg. span ( ) ) ;
652- quote ! {
653- function(
654- // NB #self_arg includes a comma, so none inserted here
655- #pyo3_path:: impl_:: deprecations:: inspect_type( #self_arg & #self_checker) ,
656- #( #args) , *
657- )
658- }
659- }
677+ quote ! ( function( #self_arg #( #args) , * ) )
660678 }
661679 } ;
662680 let mut call = quote ! { {
@@ -665,6 +683,7 @@ impl<'a> FnSpec<'a> {
665683 #pyo3_path:: intern!( py, stringify!( #python_name) ) ,
666684 #qualname_prefix,
667685 #throw_callback,
686+ #allow_threads,
668687 async move { #pyo3_path:: impl_:: wrap:: OkWrap :: wrap( future. await ) } ,
669688 )
670689 } } ;
@@ -676,20 +695,21 @@ impl<'a> FnSpec<'a> {
676695 } } ;
677696 }
678697 call
679- } else {
698+ } else if allow_threads {
680699 let self_arg = self_arg ( ) ;
681- if self_arg. is_empty ( ) {
682- quote ! { function ( # ( #args ) , * ) }
700+ let ( self_arg_name , self_arg_decl ) = if self_arg. is_empty ( ) {
701+ ( quote ! ( ) , quote ! ( ) )
683702 } else {
684- let self_checker = holders. push_gil_refs_checker ( self_arg. span ( ) ) ;
685- quote ! {
686- function(
687- // NB #self_arg includes a comma, so none inserted here
688- #pyo3_path:: impl_:: deprecations:: inspect_type( #self_arg & #self_checker) ,
689- #( #args) , *
690- )
691- }
692- }
703+ ( quote ! ( __self, ) , quote ! { let ( __self, ) = ( #self_arg) ; } )
704+ } ;
705+ quote ! { {
706+ #self_arg_decl
707+ #( let #arg_names = #args; ) *
708+ py. allow_threads( || function( #self_arg_name #( #arg_names) , * ) )
709+ } }
710+ } else {
711+ let self_arg = self_arg ( ) ;
712+ quote ! ( function( #self_arg #( #args) , * ) )
693713 } ;
694714 quotes:: map_result_into_ptr ( quotes:: ok_wrap ( call, ctx) , ctx)
695715 } ;
0 commit comments