Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shortcut macros for collections #1703

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
73 changes: 73 additions & 0 deletions pyo3-macros-backend/src/dict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use std::iter::FromIterator;
use syn::parse::{Parse, ParseBuffer, ParseStream};
use syn::punctuated::Punctuated;
use syn::Token;
use syn::{braced, Expr};

#[derive(Debug)]
pub struct PyDictLiteral {
pub py: Ident,
pub items: Vec<KeyValue>,
}

#[derive(Debug)]
pub struct KeyValue {
key: syn::Expr,
value: syn::Expr,
}

#[derive(Debug)]
struct Key(syn::Expr);

impl Parse for Key {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut tokens = vec![];

while !input.peek(Token![:]) || input.peek(Token![::]) {
let tt = input.parse::<TokenTree>()?;
tokens.push(tt);
}
let stream = TokenStream::from_iter(tokens.into_iter());

let expr = syn::parse2::<Expr>(stream)?;
Ok(Self(expr))
}
}

impl Parse for KeyValue {
fn parse(input: ParseStream) -> syn::Result<Self> {
let key: Key = input.parse()?;
let _sep: Token![:] = input.parse()?;
let value: syn::Expr = input.parse()?;

Ok(Self { key: key.0, value })
}
}

impl Parse for PyDictLiteral {
fn parse(input: ParseStream) -> syn::Result<Self> {
let py: Ident = input.parse()?;
let _arrow: Token![=>] = input.parse()?;

let body: ParseBuffer;
braced!(body in input);

let items: Punctuated<KeyValue, Token![,]> = Punctuated::parse_terminated(&body)?;

Ok(Self {
py,
items: items.into_iter().collect(),
})
}
}

impl ToTokens for KeyValue {
fn to_tokens(&self, tokens: &mut TokenStream) {
let key = &self.key;
let value = &self.value;
let ts = quote! {(#key, #value)};
tokens.extend(ts);
}
}
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod utils;
mod attributes;
mod defs;
mod deprecations;
mod dict;
mod from_pyobject;
mod konst;
mod method;
Expand All @@ -23,6 +24,7 @@ mod pyimpl;
mod pymethod;
mod pyproto;

pub use dict::PyDictLiteral;
pub use from_pyobject::build_derive_from_pyobject;
pub use module::{process_functions_in_module, py_init, PyModuleOptions};
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
Expand Down
17 changes: 16 additions & 1 deletion pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use proc_macro::TokenStream;
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyClassMethodsType,
PyFunctionOptions, PyModuleOptions,
PyDictLiteral, PyFunctionOptions, PyModuleOptions,
};
use quote::quote;
use syn::{parse::Nothing, parse_macro_input};
Expand Down Expand Up @@ -199,6 +199,21 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
.into()
}

#[proc_macro]
pub fn py_dict(input: proc_macro::TokenStream) -> TokenStream {
let PyDictLiteral { items, py } = parse_macro_input!(input as PyDictLiteral);
let stream = quote! {
(|| {
use pyo3 as _pyo3;

let dict = _pyo3::types::PyDict::new(#py);
#(dict.set_item#items?;)*
_pyo3::PyResult::Ok(dict)
})()
};
stream.into()
}

fn pyclass_impl(
attrs: TokenStream,
mut ast: syn::ItemStruct,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ pub mod proc_macro {
}

#[cfg(feature = "macros")]
pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject};
pub use pyo3_macros::{py_dict, pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject};

#[macro_use]
mod macros;
Expand Down
144 changes: 144 additions & 0 deletions src/types/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#[cfg(feature = "macros")]
pub use pyo3_macros::py_dict;

#[doc(hidden)]
#[macro_export]
macro_rules! py_object_vec {
($py:ident, [$($item:expr),+]) => {{
let items_vec: Vec<$crate::PyObject> =
vec![$($crate::conversion::IntoPy::into_py($item, $py)),+];
items_vec
}};
}

#[macro_export]
macro_rules! py_list {
($py:ident, [$($items:expr),+]) => {{
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::list::PyList::new($py, items_vec)
}};
}

#[macro_export]
macro_rules! py_tuple {
($py:ident, ($($items:expr),+)) => {{
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::PyTuple::new($py, items_vec)
}};
}

#[macro_export]
macro_rules! py_set {
($py:ident, {$($items:expr),+}) => {{
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::set::PySet::new($py, items_vec.as_slice())
}};
}

#[macro_export]
macro_rules! py_frozenset {
($py:ident, {$($items:expr),+}) => {{
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::set::PyFrozenSet::new($py, items_vec.as_slice())
}};
}

#[cfg(test)]
mod test {
use crate::types::PyFrozenSet;
use crate::Python;

#[test]
fn test_list_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let single_item_list = py_list!(py, ["elem"]);
assert_eq!(
"elem",
single_item_list
.get_item(0)
.expect("failed to get item")
.extract::<&str>()
.unwrap()
);

let multi_item_list = py_list!(py, ["elem1", "elem2", 3, 4]);

assert_eq!(
"['elem1', 'elem2', 3, 4]",
multi_item_list.str().unwrap().extract::<&str>().unwrap()
);
}

#[test]
fn test_tuple_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let single_item_tuple = py_tuple!(py, ("elem"));
assert_eq!(
"elem",
single_item_tuple
.get_item(0)
.expect("failed to get item")
.extract::<&str>()
.unwrap()
);

let multi_item_tuple = py_tuple!(py, ("elem1", "elem2", 3, 4));

assert_eq!(
"('elem1', 'elem2', 3, 4)",
multi_item_tuple.str().unwrap().extract::<&str>().unwrap()
);
}

#[test]
fn test_set_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let set = py_set!(py, { "set_elem" }).expect("failed to create set");

assert!(set.contains("set_elem").unwrap());

set.call_method1(
"update",
py_tuple!(
py,
(py_set!(py, {"new_elem1", "new_elem2", "set_elem"}).unwrap())
),
)
.expect("failed to update set");

for &expected_elem in &["set_elem", "new_elem1", "new_elem2"] {
assert!(set.contains(expected_elem).unwrap());
}
}

#[test]
fn test_frozenset_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let frozenset = py_frozenset!(py, { "set_elem" }).expect("failed to create frozenset");

assert!(frozenset.contains("set_elem").unwrap());

let intersection = frozenset
.call_method1(
"intersection",
py_tuple!(
py,
(py_set!(py, {"new_elem1", "new_elem2", "set_elem"}).unwrap())
),
)
.expect("failed to call intersection()")
.downcast::<PyFrozenSet>()
.expect("failed to downcast to FrozenSet");

assert_eq!(1, intersection.len());
assert!(intersection.contains("set_elem").unwrap());
}
}
1 change: 1 addition & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ mod floatob;
mod function;
mod iterator;
mod list;
mod macros;
mod mapping;
mod module;
mod num;
Expand Down
52 changes: 52 additions & 0 deletions tests/test_literals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#![cfg(feature = "macros")]

use pyo3::prelude::*;
use pyo3::{py_dict, py_run, py_tuple};

#[test]
fn test_dict_literal() {
let gil = Python::acquire_gil();
let py = gil.python();

let dict = py_dict!(py => {"key": "value"}).expect("failed to create dict");
assert_eq!(
"value",
dict.get_item("key").unwrap().extract::<String>().unwrap()
);

let value = "value";
let multi_elem_dict =
py_dict!(py => {"key1": value, 143: "abcde"}).expect("failed to create dict");
assert_eq!(
"value",
multi_elem_dict
.get_item("key1")
.unwrap()
.extract::<&str>()
.unwrap()
);
assert_eq!(
"abcde",
multi_elem_dict
.get_item(143)
.unwrap()
.extract::<&str>()
.unwrap()
);

let keys = &["key1", "key2"];

let expr_dict = py_dict!(py => {
keys[0]: "value1",
keys[1]: "value2",
3-7: py_tuple!(py, ("elem1", "elem2", 3)),
"KeY".to_lowercase(): 100 * 2,
})
.expect("failed to create dict");

py_run!(
py,
expr_dict,
"assert expr_dict == {'key1': 'value1', 'key2': 'value2', -4: ('elem1', 'elem2', 3), 'key': 200}"
);
}