Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
csernazs committed Jul 16, 2024
1 parent 56f70a2 commit 7eba98c
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 58 deletions.
37 changes: 1 addition & 36 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
get_doc,
pyclass::PyClassPyO3Option,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::{Ctx, LitCStr, PyO3CratePath},
utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
};
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down Expand Up @@ -565,11 +565,6 @@ fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bo
found
}

enum IdentOrStr<'a> {
Str(&'a str),
Ident(syn::Ident),
}

impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
fn eq(&self, other: &syn::Ident) -> bool {
match self {
Expand All @@ -578,36 +573,6 @@ impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
}
}
}
fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
has_attribute_with_namespace(attrs, None, &[ident])
}

fn has_attribute_with_namespace(
attrs: &[syn::Attribute],
crate_path: Option<&PyO3CratePath>,
idents: &[&str],
) -> bool {
let mut segments = vec![];
if let Some(c) = crate_path {
match c {
PyO3CratePath::Given(paths) => {
for p in &paths.segments {
segments.push(IdentOrStr::Ident(p.ident.clone()));
}
}
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
}
};
for i in idents {
segments.push(IdentOrStr::Str(i));
}

attrs.iter().any(|attr| {
segments
.iter()
.eq(attr.path().segments.iter().map(|v| &v.ident))
})
}

fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
Expand Down
31 changes: 19 additions & 12 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;

use crate::utils::Ctx;
use crate::utils::{has_attribute, has_attribute_with_namespace, Ctx, PyO3CratePath};
use crate::{
attributes::{take_pyo3_options, CrateAttribute},
konst::{ConstAttributes, ConstSpec},
Expand Down Expand Up @@ -85,16 +85,23 @@ pub fn build_py_methods(
}
}

fn check_pyfunction(meth: &mut ImplItemFn) -> syn::Result<()> {
if meth
.attrs
.iter()
.any(|attr| attr.path().is_ident("pyfunction"))
{
meth.attrs.clear();
bail_spanned!(meth.span() => "functions inside #[pymethods] do not need to be annotated with #[pyfunction]");
}
Ok(())
fn check_pyfunction(pyo3_path: &PyO3CratePath, meth: &mut ImplItemFn) -> syn::Result<()> {
let mut error = None;

meth.attrs.retain(|attr| {
let attrs = [attr.clone()];

if has_attribute(&attrs, "pyfunction")
|| has_attribute_with_namespace(&attrs, Some(pyo3_path), &["pyfunction"])
|| has_attribute_with_namespace(&attrs, Some(pyo3_path), &["prelude", "pyfunction"]) {
error = Some(err_spanned!(meth.sig.span() => "functions inside #[pymethods] do not need to be annotated with #[pyfunction]"));
false
} else {
true
}
});

error.map_or(Ok(()), Err)
}

pub fn impl_methods(
Expand All @@ -117,7 +124,7 @@ pub fn impl_methods(
let mut fun_options = PyFunctionOptions::from_attrs(&mut meth.attrs)?;
fun_options.krate = fun_options.krate.or_else(|| options.krate.clone());

check_pyfunction(meth)?;
check_pyfunction(&ctx.pyo3_path, meth)?;

match pymethod::gen_py_method(ty, &mut meth.sig, &mut meth.attrs, fun_options, ctx)?
{
Expand Down
36 changes: 36 additions & 0 deletions pyo3-macros-backend/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,39 @@ pub fn apply_renaming_rule(rule: RenamingRule, name: &str) -> String {
pub(crate) fn is_abi3() -> bool {
pyo3_build_config::get().abi3
}

pub(crate) enum IdentOrStr<'a> {
Str(&'a str),
Ident(syn::Ident),
}

pub(crate) fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
has_attribute_with_namespace(attrs, None, &[ident])
}

pub(crate) fn has_attribute_with_namespace(
attrs: &[syn::Attribute],
crate_path: Option<&PyO3CratePath>,
idents: &[&str],
) -> bool {
let mut segments = vec![];
if let Some(c) = crate_path {
match c {
PyO3CratePath::Given(paths) => {
for p in &paths.segments {
segments.push(IdentOrStr::Ident(p.ident.clone()));
}
}
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
}
};
for i in idents {
segments.push(IdentOrStr::Str(i));
}

attrs.iter().any(|attr| {
segments
.iter()
.eq(attr.path().segments.iter().map(|v| &v.ident))
})
}
4 changes: 2 additions & 2 deletions tests/ui/invalid_pyfunction_definition.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


#[pyo3::pymodule]
mod pyo3_scratch {
use pyo3::prelude::*;
Expand All @@ -13,3 +11,5 @@ mod pyo3_scratch {
fn bug() {}
}
}

fn main() {}
10 changes: 2 additions & 8 deletions tests/ui/invalid_pyfunction_definition.stderr
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
error: functions inside #[pymethods] do not need to be annotated with #[pyfunction]
--> tests/ui/invalid_pyfunction_definition.rs:13:9
--> tests/ui/invalid_pyfunction_definition.rs:11:9
|
13 | fn bug() {}
11 | fn bug() {}
| ^^

error[E0601]: `main` function not found in crate `$CRATE`
--> tests/ui/invalid_pyfunction_definition.rs:15:2
|
15 | }
| ^ consider adding a `main` function to `$DIR/tests/ui/invalid_pyfunction_definition.rs`

0 comments on commit 7eba98c

Please sign in to comment.