diff --git a/Cargo.lock b/Cargo.lock index 2d36ecf..befd21a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,7 +301,7 @@ dependencies = [ [[package]] name = "mixed" -version = "0.5.4" +version = "0.6.0" dependencies = [ "env_logger", "pyo3", @@ -310,7 +310,7 @@ dependencies = [ [[package]] name = "mixed_sub" -version = "0.5.4" +version = "0.6.0" dependencies = [ "env_logger", "pyo3", @@ -428,7 +428,7 @@ dependencies = [ [[package]] name = "pure" -version = "0.5.4" +version = "0.6.0" dependencies = [ "ahash", "env_logger", @@ -501,7 +501,7 @@ dependencies = [ [[package]] name = "pyo3-stub-gen" -version = "0.5.4" +version = "0.6.0" dependencies = [ "anyhow", "inventory", @@ -518,7 +518,7 @@ dependencies = [ [[package]] name = "pyo3-stub-gen-derive" -version = "0.5.4" +version = "0.6.0" dependencies = [ "insta", "inventory", diff --git a/Cargo.toml b/Cargo.toml index b118718..dd26897 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.5.4" +version = "0.6.0" edition = "2021" description = "Stub file (*.pyi) generator for PyO3" diff --git a/examples/pure/pure.pyi b/examples/pure/pure.pyi index a373091..7f2f309 100644 --- a/examples/pure/pure.pyi +++ b/examples/pure/pure.pyi @@ -12,6 +12,9 @@ class A: def show_x(self) -> None: ... + def ref_test(self, x:dict) -> dict: + ... + class Number(Enum): FLOAT = auto() diff --git a/examples/pure/src/lib.rs b/examples/pure/src/lib.rs index 0395fdb..6061169 100644 --- a/examples/pure/src/lib.rs +++ b/examples/pure/src/lib.rs @@ -2,7 +2,7 @@ mod readme {} use ahash::RandomState; -use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use pyo3::{exceptions::PyRuntimeError, prelude::*, types::*}; use pyo3_stub_gen::{create_exception, define_stub_info_gatherer, derive::*}; use std::{collections::HashMap, path::PathBuf}; @@ -52,6 +52,10 @@ impl A { fn show_x(&self) { println!("x = {}", self.x); } + + fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> { + x + } } #[gen_stub_pyfunction] diff --git a/pyo3-stub-gen-derive/src/gen_stub/arg.rs b/pyo3-stub-gen-derive/src/gen_stub/arg.rs index 9157640..652993d 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/arg.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/arg.rs @@ -1,8 +1,8 @@ +use super::remove_lifetime; use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{ spanned::Spanned, FnArg, GenericArgument, PatType, PathArguments, Result, Type, TypePath, - TypeReference, }; pub fn parse_args(iter: impl IntoIterator) -> Result> { @@ -20,8 +20,8 @@ pub fn parse_args(iter: impl IntoIterator) -> Result> // Regard the first argument with `PyRef<'_, Self>` and `PyMutRef<'_, Self>` types as a receiver. if n == 0 && (last.ident == "PyRef" || last.ident == "PyRefMut") { if let PathArguments::AngleBracketed(inner) = &last.arguments { - assert!(inner.args.len() == 2); - if let GenericArgument::Type(Type::Path(TypePath { path, .. })) = &inner.args[1] + if let GenericArgument::Type(Type::Path(TypePath { path, .. })) = + &inner.args[inner.args.len() - 1] { let last = path.segments.last().unwrap(); if last.ident == "Self" { @@ -57,55 +57,16 @@ impl TryFrom for ArgInfo { } } -fn type_to_token(ty: &Type) -> TokenStream2 { - match ty { - Type::Path(TypePath { path, .. }) => { - if let Some(last_seg) = path.segments.last() { - // `CompareOp` is an enum for `__richcmp__` - // PyO3 reference: https://docs.rs/pyo3/latest/pyo3/pyclass/enum.CompareOp.html - // PEP: https://peps.python.org/pep-0207/ - if last_seg.ident == "CompareOp" { - quote! { ::pyo3_stub_gen::type_info::compare_op_type_input } - } else { - quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input } - } - } else { - unreachable!("Empty path segment: {:?}", path); - } - } - Type::Reference(TypeReference { elem, .. }) => { - match elem.as_ref() { - Type::Path(TypePath { path, .. }) => { - if let Some(last) = path.segments.last() { - match last.ident.to_string().as_str() { - // Types where `&T: ::pyo3_stub_gen::PyStubType` instead of `T: ::pyo3_stub_gen::PyStubType` - // i.e. `&str` and most of `Py*` types defined in PyO3. - "str" | "PyAny" | "PyString" | "PyDict" => { - return quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input }; - } - _ => {} - } - } - } - Type::Slice(_) => { - return quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input }; - } - _ => {} - } - type_to_token(elem) - } - _ => { - quote! { <#ty as ::pyo3_stub_gen::PyStubType>::type_input } - } - } -} - impl ToTokens for ArgInfo { fn to_tokens(&self, tokens: &mut TokenStream2) { let Self { name, r#type: ty } = self; - let type_tt = type_to_token(ty); + let mut ty = ty.clone(); + remove_lifetime(&mut ty); tokens.append_all(quote! { - ::pyo3_stub_gen::type_info::ArgInfo { name: #name, r#type: #type_tt } + ::pyo3_stub_gen::type_info::ArgInfo { + name: #name, + r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input + } }); } } diff --git a/pyo3-stub-gen-derive/src/gen_stub/util.rs b/pyo3-stub-gen-derive/src/gen_stub/util.rs index 50aa9e3..a3e1107 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/util.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/util.rs @@ -36,6 +36,14 @@ pub fn remove_lifetime(ty: &mut Type) { rty.lifetime = None; remove_lifetime(rty.elem.as_mut()); } + Type::Tuple(ty) => { + for elem in &mut ty.elems { + remove_lifetime(elem); + } + } + Type::Array(ary) => { + remove_lifetime(ary.elem.as_mut()); + } _ => {} } } diff --git a/pyo3-stub-gen/Cargo.toml b/pyo3-stub-gen/Cargo.toml index 6faad41..e1c77b7 100644 --- a/pyo3-stub-gen/Cargo.toml +++ b/pyo3-stub-gen/Cargo.toml @@ -21,7 +21,7 @@ serde.workspace = true toml.workspace = true [dependencies.pyo3-stub-gen-derive] -version = "0.5.4" +version = "0.6.0" path = "../pyo3-stub-gen-derive" [features] diff --git a/pyo3-stub-gen/src/stub_type/collections.rs b/pyo3-stub-gen/src/stub_type/collections.rs index ec51b26..209025e 100644 --- a/pyo3-stub-gen/src/stub_type/collections.rs +++ b/pyo3-stub-gen/src/stub_type/collections.rs @@ -1,6 +1,15 @@ use crate::stub_type::*; use std::collections::{BTreeMap, BTreeSet, HashMap}; +impl PyStubType for &T { + fn type_input() -> TypeInfo { + T::type_input() + } + fn type_output() -> TypeInfo { + T::type_output() + } +} + impl PyStubType for Option { fn type_input() -> TypeInfo { let TypeInfo { name, mut import } = T::type_input(); diff --git a/pyo3-stub-gen/src/stub_type/pyo3.rs b/pyo3-stub-gen/src/stub_type/pyo3.rs index a5fd7b0..ead4541 100644 --- a/pyo3-stub-gen/src/stub_type/pyo3.rs +++ b/pyo3-stub-gen/src/stub_type/pyo3.rs @@ -1,5 +1,6 @@ use crate::stub_type::*; use ::pyo3::{ + basic::CompareOp, pybacked::{PyBackedBytes, PyBackedStr}, pyclass::boolean_struct::False, types::*, @@ -78,3 +79,4 @@ impl_builtin!(PyByteArray, "bytearray"); impl_builtin!(PyBytes, "bytes"); impl_builtin!(PyBackedBytes, "bytes"); impl_builtin!(PyType, "type"); +impl_builtin!(CompareOp, "int");