Skip to content

Commit

Permalink
allow expressions in py_dict!
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-konovalenko committed Dec 18, 2021
1 parent 7166882 commit 85b8dbc
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 115 deletions.
78 changes: 78 additions & 0 deletions pyo3-macros-backend/src/dict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use std::iter::FromIterator;
use syn::parse::{Parse, ParseBuffer, ParseStream};
use syn::punctuated::Punctuated;
use syn::Token;
use syn::{braced, Expr};

#[derive(Debug)]
pub struct PyDictLiteral {
pub py: Ident,
pub items: Vec<KeyValue>,
}

#[derive(Debug)]
pub struct KeyValue {
key: syn::Expr,
sep: Token![:],
value: syn::Expr,
}

#[derive(Debug)]
struct Key(syn::Expr);

impl Parse for Key {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut tokens = vec![];

while !input.peek(Token![:]) || input.peek(Token![::]) {
let tt = input.parse::<TokenTree>()?;
tokens.push(tt);
}
let stream = TokenStream::from_iter(tokens.into_iter());

let expr = syn::parse2::<Expr>(stream)?;
Ok(Self(expr))
}
}

impl Parse for KeyValue {
fn parse(input: ParseStream) -> syn::Result<Self> {
let key: Key = input.parse()?;
let sep: Token![:] = input.parse()?;
let value: syn::Expr = input.parse()?;

Ok(Self {
key: key.0,
sep,
value,
})
}
}

impl Parse for PyDictLiteral {
fn parse(input: ParseStream) -> syn::Result<Self> {
let py: Ident = input.parse()?;
let _arrow: Token![=>] = input.parse()?;

let body: ParseBuffer;
braced!(body in input);

let items: Punctuated<KeyValue, Token![,]> = Punctuated::parse_terminated(&body)?;

Ok(Self {
py,
items: items.into_iter().collect(),
})
}
}

impl ToTokens for KeyValue {
fn to_tokens(&self, tokens: &mut TokenStream) {
let key = &self.key;
let value = &self.value;
let ts = quote! {(#key, #value)};
tokens.extend(ts);
}
}
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod utils;
mod attributes;
mod defs;
mod deprecations;
mod dict;
mod from_pyobject;
mod konst;
mod method;
Expand All @@ -23,6 +24,7 @@ mod pyimpl;
mod pymethod;
mod pyproto;

pub use dict::PyDictLiteral;
pub use from_pyobject::build_derive_from_pyobject;
pub use module::{process_functions_in_module, py_init, PyModuleOptions};
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
Expand Down
15 changes: 14 additions & 1 deletion pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use proc_macro::TokenStream;
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyClassMethodsType,
PyFunctionOptions, PyModuleOptions,
PyDictLiteral, PyFunctionOptions, PyModuleOptions,
};
use quote::quote;
use syn::{parse::Nothing, parse_macro_input};
Expand Down Expand Up @@ -199,6 +199,19 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
.into()
}

#[proc_macro]
pub fn py_dict(input: proc_macro::TokenStream) -> TokenStream {
let PyDictLiteral { items, py } = parse_macro_input!(input as PyDictLiteral);
let stream = quote! {
(|| {
let dict = ::pyo3::types::PyDict::new(#py);
#(dict.set_item#items?;)*
::pyo3::prelude::PyResult::Ok(dict)
})()
};
stream.into()
}

fn pyclass_impl(
attrs: TokenStream,
mut ast: syn::ItemStruct,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ pub mod proc_macro {
}

#[cfg(feature = "macros")]
pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject};
pub use pyo3_macros::{py_dict, pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject};

#[macro_use]
mod macros;
Expand Down
46 changes: 0 additions & 46 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,17 +369,6 @@ where
Ok(ret)
}
}
macro_rules! py_dict {
($py:ident, {$key:literal : $value:expr}) => {
[($key, $value)].into_py_dict($py)
};

($py:ident, {$key:literal : $value:expr, $($keys:literal : $values:expr),+}) => {{
let dct = py_dict!($py, {$($keys : $values),+});
dct.set_item($key, $value).expect("failed to set item on dict");
dct
}};
}

#[cfg(test)]
mod tests {
Expand All @@ -393,41 +382,6 @@ mod tests {
use crate::{PyTryFrom, ToPyObject};
use std::collections::{BTreeMap, HashMap};

#[test]
fn test_dict_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let single_elem_dict = py_dict!(py, { "a": 2 });
assert_eq!(
2,
single_elem_dict
.get_item("a")
.unwrap()
.extract::<i32>()
.unwrap()
);

let value = "value";
let multi_elem_dict = py_dict!(py, {"key1": value, 143: "abcde"});
assert_eq!(
"value",
multi_elem_dict
.get_item("key1")
.unwrap()
.extract::<&str>()
.unwrap()
);
assert_eq!(
"abcde",
multi_elem_dict
.get_item(143)
.unwrap()
.extract::<&str>()
.unwrap()
);
}

#[test]
fn test_new() {
Python::with_gil(|py| {
Expand Down
88 changes: 21 additions & 67 deletions src/types/macros.rs
Original file line number Diff line number Diff line change
@@ -1,95 +1,53 @@
pub use pyo3_macros::py_dict;

#[doc(hidden)]
#[macro_export]
macro_rules! py_object_vec {
($py:ident, [$($item:expr),+]) => {{
let items_vec: Vec<$crate::instance::PyObject> =
let items_vec: Vec<$crate::PyObject> =
vec![$($crate::conversion::IntoPy::into_py($item, $py)),+];
items_vec
}};
}

#[macro_export]
macro_rules! py_dict {
($py:ident, {$($keys:literal : $values:expr),+}) => {{
let items: $crate::instance::PyObject = py_list!($py, [$(($keys, $values)),+]).into();

$crate::types::dict::PyDict::from_sequence($py, items)
}};
}

#[macro_export]
macro_rules! py_list {
($py:ident, [$($items:expr),+]) => {{
let items_vec = py_object_vec!($py, [$($items),+]);
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::list::PyList::new($py, items_vec)
}};
}

#[macro_export]
macro_rules! py_tuple {
($py:ident, ($($items:expr),+)) => {{
let items_vec = py_object_vec!($py, [$($items),+]);
$crate::types::tuple::PyTuple::new($py, items_vec)
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::PyTuple::new($py, items_vec)
}};
}

#[macro_export]
macro_rules! py_set {
($py:ident, {$($items:expr),+}) => {{
let items_vec = py_object_vec!($py, [$($items),+]);
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::set::PySet::new($py, items_vec.as_slice())
}};
}

#[macro_export]
macro_rules! py_frozenset {
($py:ident, {$($items:expr),+}) => {{
let items_vec = py_object_vec!($py, [$($items),+]);
let items_vec = $crate::py_object_vec!($py, [$($items),+]);
$crate::types::set::PyFrozenSet::new($py, items_vec.as_slice())
}};
}

#[cfg(test)]
mod test {
use crate::py_dict;
use crate::types::PyFrozenSet;
use crate::Python;

#[test]
fn test_dict_macro() {
let gil = Python::acquire_gil();
let py = gil.python();

let single_elem_dict = py_dict!(py, { "a": 2 }).expect("failed to create dict");
assert_eq!(
2,
single_elem_dict
.get_item("a")
.unwrap()
.extract::<i32>()
.unwrap()
);

let value = "value";
let multi_elem_dict = py_dict!(py, {"key1": value, 143: "abcde", "name": "Даня"})
.expect("failed to create dict");
assert_eq!(
"value",
multi_elem_dict
.get_item("key1")
.unwrap()
.extract::<&str>()
.unwrap()
);
assert_eq!(
"abcde",
multi_elem_dict
.get_item(143)
.unwrap()
.extract::<&str>()
.unwrap()
);
}

#[test]
fn test_list_macro() {
let gil = Python::acquire_gil();
Expand All @@ -98,18 +56,16 @@ mod test {
let single_item_list = py_list!(py, ["elem"]);
assert_eq!(
"elem",
single_item_list.get_item(0).extract::<&str>().unwrap()
single_item_list
.get_item(0)
.expect("failed to get item")
.extract::<&str>()
.unwrap()
);

let multi_item_list = py_list!(
py,
[
"elem1",
"elem2",
3,
4,
py_dict!(py, {"type": "user"}).unwrap()
]
["elem1", "elem2", 3, 4, py_dict!({"type": "user"}).unwrap()]
);

assert_eq!(
Expand All @@ -126,18 +82,16 @@ mod test {
let single_item_tuple = py_tuple!(py, ("elem"));
assert_eq!(
"elem",
single_item_tuple.get_item(0).extract::<&str>().unwrap()
single_item_tuple
.get_item(0)
.expect("failed to get item")
.extract::<&str>()
.unwrap()
);

let multi_item_tuple = py_tuple!(
py,
(
"elem1",
"elem2",
3,
4,
py_dict!(py, {"type": "user"}).unwrap()
)
("elem1", "elem2", 3, 4, py_dict!({"type": "user"}).unwrap())
);

assert_eq!(
Expand Down
46 changes: 46 additions & 0 deletions tests/test_literals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use pyo3::prelude::*;
use pyo3::{py_dict, py_run, py_tuple};

#[test]
fn test_dict_literal() {
let gil = Python::acquire_gil();
let py = gil.python();

let dict = py_dict!(py => {"key": "value"}).expect("failed to create dict");
assert_eq!(
"value",
dict.get_item("key").unwrap().extract::<String>().unwrap()
);

let value = "value";
let multi_elem_dict =
py_dict!(py => {"key1": value, 143: "abcde"}).expect("failed to create dict");
assert_eq!(
"value",
multi_elem_dict
.get_item("key1")
.unwrap()
.extract::<&str>()
.unwrap()
);
assert_eq!(
"abcde",
multi_elem_dict
.get_item(143)
.unwrap()
.extract::<&str>()
.unwrap()
);

let keys = &["key1", "key2"];

let expr_dict =
py_dict!(py => { keys[0]: "value1", keys[1]: "value2", 3-7: py_tuple!(py, ("elem1", "elem2", 3)) })
.expect("failed to create dict");

py_run!(
py,
expr_dict,
"assert expr_dict == {'key1': 'value1', 'key2': 'value2', -4: ('elem1', 'elem2', 3)}"
);
}

0 comments on commit 85b8dbc

Please sign in to comment.