Skip to content

Commit

Permalink
Simplify code generation for ArgInfo (#62)
Browse files Browse the repository at this point in the history
- Cleanup and simplify `impl ToToken for ArgInfo` which based on
`pyo3/experimental-inspect` feature as before #37
- This can yield breaking changes. Bump up to 0.6.0.
  • Loading branch information
termoshtt authored Aug 22, 2024
1 parent 53f50d2 commit cea14d1
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 56 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ members = [
resolver = "2"

[workspace.package]
version = "0.5.4"
version = "0.6.0"
edition = "2021"

description = "Stub file (*.pyi) generator for PyO3"
Expand Down
3 changes: 3 additions & 0 deletions examples/pure/pure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class A:
def show_x(self) -> None:
...

def ref_test(self, x:dict) -> dict:
...


class Number(Enum):
FLOAT = auto()
Expand Down
6 changes: 5 additions & 1 deletion examples/pure/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
mod readme {}

use ahash::RandomState;
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::*};
use pyo3_stub_gen::{create_exception, define_stub_info_gatherer, derive::*};
use std::{collections::HashMap, path::PathBuf};

Expand Down Expand Up @@ -52,6 +52,10 @@ impl A {
fn show_x(&self) {
println!("x = {}", self.x);
}

fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> {
x
}
}

#[gen_stub_pyfunction]
Expand Down
57 changes: 9 additions & 48 deletions pyo3-stub-gen-derive/src/gen_stub/arg.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::remove_lifetime;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{
spanned::Spanned, FnArg, GenericArgument, PatType, PathArguments, Result, Type, TypePath,
TypeReference,
};

pub fn parse_args(iter: impl IntoIterator<Item = FnArg>) -> Result<Vec<ArgInfo>> {
Expand All @@ -20,8 +20,8 @@ pub fn parse_args(iter: impl IntoIterator<Item = FnArg>) -> Result<Vec<ArgInfo>>
// Regard the first argument with `PyRef<'_, Self>` and `PyMutRef<'_, Self>` types as a receiver.
if n == 0 && (last.ident == "PyRef" || last.ident == "PyRefMut") {
if let PathArguments::AngleBracketed(inner) = &last.arguments {
assert!(inner.args.len() == 2);
if let GenericArgument::Type(Type::Path(TypePath { path, .. })) = &inner.args[1]
if let GenericArgument::Type(Type::Path(TypePath { path, .. })) =
&inner.args[inner.args.len() - 1]
{
let last = path.segments.last().unwrap();
if last.ident == "Self" {
Expand Down Expand Up @@ -57,55 +57,16 @@ impl TryFrom<FnArg> for ArgInfo {
}
}

fn type_to_token(ty: &Type) -> TokenStream2 {
match ty {
Type::Path(TypePath { path, .. }) => {
if let Some(last_seg) = path.segments.last() {
// `CompareOp` is an enum for `__richcmp__`
// PyO3 reference: https://docs.rs/pyo3/latest/pyo3/pyclass/enum.CompareOp.html
// PEP: https://peps.python.org/pep-0207/
if last_seg.ident == "CompareOp" {
quote! { ::pyo3_stub_gen::type_info::compare_op_type_input }
} else {
quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input }
}
} else {
unreachable!("Empty path segment: {:?}", path);
}
}
Type::Reference(TypeReference { elem, .. }) => {
match elem.as_ref() {
Type::Path(TypePath { path, .. }) => {
if let Some(last) = path.segments.last() {
match last.ident.to_string().as_str() {
// Types where `&T: ::pyo3_stub_gen::PyStubType` instead of `T: ::pyo3_stub_gen::PyStubType`
// i.e. `&str` and most of `Py*` types defined in PyO3.
"str" | "PyAny" | "PyString" | "PyDict" => {
return quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input };
}
_ => {}
}
}
}
Type::Slice(_) => {
return quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input };
}
_ => {}
}
type_to_token(elem)
}
_ => {
quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input }
}
}
}

impl ToTokens for ArgInfo {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let Self { name, r#type: ty } = self;
let type_tt = type_to_token(ty);
let mut ty = ty.clone();
remove_lifetime(&mut ty);
tokens.append_all(quote! {
::pyo3_stub_gen::type_info::ArgInfo { name: #name, r#type: #type_tt }
::pyo3_stub_gen::type_info::ArgInfo {
name: #name,
r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input
}
});
}
}
8 changes: 8 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ pub fn remove_lifetime(ty: &mut Type) {
rty.lifetime = None;
remove_lifetime(rty.elem.as_mut());
}
Type::Tuple(ty) => {
for elem in &mut ty.elems {
remove_lifetime(elem);
}
}
Type::Array(ary) => {
remove_lifetime(ary.elem.as_mut());
}
_ => {}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-stub-gen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ serde.workspace = true
toml.workspace = true

[dependencies.pyo3-stub-gen-derive]
version = "0.5.4"
version = "0.6.0"
path = "../pyo3-stub-gen-derive"

[features]
Expand Down
9 changes: 9 additions & 0 deletions pyo3-stub-gen/src/stub_type/collections.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use crate::stub_type::*;
use std::collections::{BTreeMap, BTreeSet, HashMap};

impl<T: PyStubType> PyStubType for &T {
fn type_input() -> TypeInfo {
T::type_input()
}
fn type_output() -> TypeInfo {
T::type_output()
}
}

impl<T: PyStubType> PyStubType for Option<T> {
fn type_input() -> TypeInfo {
let TypeInfo { name, mut import } = T::type_input();
Expand Down
2 changes: 2 additions & 0 deletions pyo3-stub-gen/src/stub_type/pyo3.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::stub_type::*;
use ::pyo3::{
basic::CompareOp,
pybacked::{PyBackedBytes, PyBackedStr},
pyclass::boolean_struct::False,
types::*,
Expand Down Expand Up @@ -78,3 +79,4 @@ impl_builtin!(PyByteArray, "bytearray");
impl_builtin!(PyBytes, "bytes");
impl_builtin!(PyBackedBytes, "bytes");
impl_builtin!(PyType, "type");
impl_builtin!(CompareOp, "int");

0 comments on commit cea14d1

Please sign in to comment.