Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 127 additions & 37 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
use crate::{
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::{impl_arg_params, Holders},
pyfunction::{
Expand All @@ -17,19 +17,107 @@ use crate::{
};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
pub struct RegularArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
pub optional: Option<&'a syn::Type>,
pub default: Option<syn::Expr>,
pub py: bool,
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
pub default_value: Option<syn::Expr>,
pub option_wrapped_type: Option<&'a syn::Type>,
}

#[derive(Clone, Debug)]
pub struct VarArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub struct KwArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub struct CancelHandleArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub struct PyArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub enum FnArg<'a> {
Regular(RegularArg<'a>),
VarArgs(VarArg<'a>),
KwArgs(KwArg<'a>),
Py(PyArg<'a>),
CancelHandle(CancelHandleArg<'a>),
}

impl<'a> FnArg<'a> {
pub fn name(&self) -> &'a syn::Ident {
match self {
FnArg::Regular(RegularArg { name, .. }) => name,
FnArg::VarArgs(VarArg { name, .. }) => name,
FnArg::KwArgs(KwArg { name, .. }) => name,
FnArg::Py(PyArg { name, .. }) => name,
FnArg::CancelHandle(CancelHandleArg { name, .. }) => name,
}
}

pub fn ty(&self) -> &'a syn::Type {
match self {
FnArg::Regular(RegularArg { ty, .. }) => ty,
FnArg::VarArgs(VarArg { ty, .. }) => ty,
FnArg::KwArgs(KwArg { ty, .. }) => ty,
FnArg::Py(PyArg { ty, .. }) => ty,
FnArg::CancelHandle(CancelHandleArg { ty, .. }) => ty,
}
}

#[allow(clippy::wrong_self_convention)]
pub fn from_py_with(&self) -> Option<&FromPyWithAttribute> {
if let FnArg::Regular(RegularArg { attrs, .. }) = self {
attrs.from_py_with.as_ref()
} else {
None
}
}

pub fn to_varargs_mut(&mut self) -> Result<&mut Self> {
if let Self::Regular(RegularArg {
name,
ty,
option_wrapped_type: None,
..
}) = self
{
*self = Self::VarArgs(VarArg { name, ty });
Ok(self)
} else {
bail_spanned!(self.name().span() => "args cannot be optional")
}
}

pub fn to_kwargs_mut(&mut self) -> Result<&mut Self> {
if let Self::Regular(RegularArg {
name,
ty,
option_wrapped_type: Some(..),
..
}) = self
{
*self = Self::KwArgs(KwArg { name, ty });
Ok(self)
} else {
bail_spanned!(self.name().span() => "kwargs must be Option<_>")
}
}

/// Transforms a rust fn arg parsed with syn into a method::FnArg
pub fn parse(arg: &'a mut syn::FnArg) -> Result<Self> {
match arg {
Expand All @@ -47,26 +135,30 @@ impl<'a> FnArg<'a> {
other => return Err(handle_argument_error(other)),
};

let is_cancel_handle = arg_attrs.cancel_handle.is_some();
if utils::is_python(&cap.ty) {
return Ok(Self::Py(PyArg {
name: ident,
ty: &cap.ty,
}));
}

Ok(FnArg {
if arg_attrs.cancel_handle.is_some() {
return Ok(Self::CancelHandle(CancelHandleArg {
name: ident,
ty: &cap.ty,
}));
}

Ok(Self::Regular(RegularArg {
name: ident,
ty: &cap.ty,
optional: utils::option_type_argument(&cap.ty),
default: None,
py: utils::is_python(&cap.ty),
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle,
})
default_value: None,
option_wrapped_type: utils::option_type_argument(&cap.ty),
}))
}
}
}

pub fn is_regular(&self) -> bool {
!self.py && !self.is_cancel_handle && !self.is_kwargs && !self.is_varargs
}
}

fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
Expand Down Expand Up @@ -492,12 +584,14 @@ impl<'a> FnSpec<'a> {
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
.filter(|arg| matches!(arg, FnArg::CancelHandle(..)));
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
if let Some(FnArg::CancelHandle(CancelHandleArg { name, .. })) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(FnArg::CancelHandle(CancelHandleArg { name, .. })) =
cancel_handle_iter.next()
{
bail_spanned!(name.span() => "`cancel_handle` may only be specified once");
}
}

Expand Down Expand Up @@ -616,14 +710,10 @@ impl<'a> FnSpec<'a> {
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
.map(|arg| match arg {
FnArg::Py(..) => quote!(py),
FnArg::CancelHandle(..) => quote!(__cancel_handle),
_ => unreachable!("`CallingConvention::Noargs` should not contain any arguments (reaching Python) except for `self`, which is handled below."),
})
.collect();
let call = rust_call(args, &mut holders);
Expand All @@ -646,7 +736,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::Fastcall => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx);
let call = rust_call(args, &mut holders);
let init_holders = holders.init_holders(ctx);
let check_gil_refs = holders.check_gil_refs();
Expand All @@ -671,7 +761,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::Varargs => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx);
let call = rust_call(args, &mut holders);
let init_holders = holders.init_holders(ctx);
let check_gil_refs = holders.check_gil_refs();
Expand All @@ -695,7 +785,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx);
let self_arg = self
.tp
.self_arg(cls, ExtractErrorMode::Raise, &mut holders, ctx);
Expand Down
Loading