Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add tokio async support #150

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions borsh-derive-internal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod enum_discriminant_map;
mod enum_ser;
mod struct_de;
mod struct_ser;
mod tokio;
mod union_de;
mod union_ser;

Expand All @@ -17,3 +18,10 @@ pub use struct_de::struct_de;
pub use struct_ser::struct_ser;
pub use union_de::union_de;
pub use union_ser::union_ser;

pub use tokio::enum_de as tokio_enum_de;
pub use tokio::enum_ser as tokio_enum_ser;
pub use tokio::struct_de as tokio_struct_de;
pub use tokio::struct_ser as tokio_struct_ser;
pub use tokio::union_de as tokio_union_de;
pub use tokio::union_ser as tokio_union_ser;
13 changes: 13 additions & 0 deletions borsh-derive-internal/src/tokio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
mod enum_de;
mod enum_ser;
mod struct_de;
mod struct_ser;
mod union_de;
mod union_ser;

pub use enum_de::enum_de;
pub use enum_ser::enum_ser;
pub use struct_de::struct_de;
pub use struct_ser::struct_ser;
pub use union_de::union_de;
pub use union_ser::union_ser;
113 changes: 113 additions & 0 deletions borsh-derive-internal/src/tokio/enum_de.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Fields, Ident, ItemEnum, WhereClause};

use crate::{
attribute_helpers::{contains_initialize_with, contains_skip},
enum_discriminant_map::discriminant_map,
};

pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut where_clause = where_clause.map_or_else(
|| WhereClause {
where_token: Default::default(),
predicates: Default::default(),
},
Clone::clone,
);
let init_method = contains_initialize_with(&input.attrs)?;
let mut variant_arms = TokenStream2::new();
let discriminants = discriminant_map(&input.variants);
for variant in input.variants.iter() {
let variant_ident = &variant.ident;
let discriminant = discriminants.get(variant_ident).unwrap();
let mut variant_header = TokenStream2::new();
match &variant.fields {
Fields::Named(fields) => {
for field in &fields.named {
let field_name = field.ident.as_ref().unwrap();
if contains_skip(&field.attrs) {
variant_header.extend(quote! {
#field_name: Default::default(),
});
} else {
let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::tokio::AsyncBorshDeserialize
})
.unwrap(),
);

variant_header.extend(quote! {
#field_name: #cratename::AsyncBorshDeserialize::deserialize_reader(reader).await?,
});
}
}
variant_header = quote! { { #variant_header }};
}
Fields::Unnamed(fields) => {
for field in fields.unnamed.iter() {
if contains_skip(&field.attrs) {
variant_header.extend(quote! { Default::default(), });
} else {
let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::tokio::AsyncBorshDeserialize
})
.unwrap(),
);

variant_header.extend(
quote! { #cratename::tokio::AsyncBorshDeserialize::deserialize_reader(reader).await?, },
);
}
}
variant_header = quote! { ( #variant_header )};
}
Fields::Unit => {}
}
variant_arms.extend(quote! {
if variant_tag == #discriminant { #name::#variant_ident #variant_header } else
});
}

let init = if let Some(method_ident) = init_method {
quote! {
return_value.#method_ident();
}
} else {
quote! {}
};

Ok(quote! {
#[async_trait::async_trait]
impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause {
async fn deserialize_reader<R: borsh::maybestd::io::Read>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let tag = <u8 as #cratename::de::BorshDeserialize>::deserialize_reader(reader).await?;
<Self as #cratename::de::EnumExt>::deserialize_variant(reader, tag).await
}
}

#[async_trait::async_trait]
impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
async fn deserialize_variant<R: #cratename::tokio::de::AsyncReader>(
reader: &mut R,
variant_tag: u8,
) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let mut return_value =
#variant_arms {
return Err(#cratename::maybestd::io::Error::new(
#cratename::maybestd::io::ErrorKind::InvalidInput,
#cratename::maybestd::format!("Unexpected variant tag: {:?}", variant_tag),
))
};
#init
Ok(return_value)
}
}
})
}
112 changes: 112 additions & 0 deletions borsh-derive-internal/src/tokio/enum_ser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use core::convert::TryFrom;

use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{Fields, Ident, ItemEnum, WhereClause};

use crate::{attribute_helpers::contains_skip, enum_discriminant_map::discriminant_map};

pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut where_clause = where_clause.map_or_else(
|| WhereClause {
where_token: Default::default(),
predicates: Default::default(),
},
Clone::clone,
);
let mut variant_idx_body = TokenStream2::new();
let mut fields_body = TokenStream2::new();
let discriminants = discriminant_map(&input.variants);
for variant in input.variants.iter() {
let variant_ident = &variant.ident;
let mut variant_header = TokenStream2::new();
let mut variant_body = TokenStream2::new();
let discriminant_value = discriminants.get(variant_ident).unwrap();
match &variant.fields {
Fields::Named(fields) => {
for field in &fields.named {
let field_name = field.ident.as_ref().unwrap();
if contains_skip(&field.attrs) {
variant_header.extend(quote! { _#field_name, });
continue;
} else {
let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::tokio::ser::AsyncBorshSerialize
})
.unwrap(),
);
variant_header.extend(quote! { #field_name, });
}
variant_body.extend(quote! {
#cratename::tokio::AsyncBorshSerialize::serialize(#field_name, writer).await?;
})
}
variant_header = quote! { { #variant_header }};
variant_idx_body.extend(quote!(
#name::#variant_ident { .. } => #discriminant_value,
));
}
Fields::Unnamed(fields) => {
for (field_idx, field) in fields.unnamed.iter().enumerate() {
let field_idx =
u32::try_from(field_idx).expect("up to 2^32 fields are supported");
if contains_skip(&field.attrs) {
let field_ident =
Ident::new(format!("_id{}", field_idx).as_str(), Span::call_site());
variant_header.extend(quote! { #field_ident, });
continue;
} else {
let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::tokio::ser::AsyncBorshSerialize
})
.unwrap(),
);

let field_ident =
Ident::new(format!("id{}", field_idx).as_str(), Span::call_site());
variant_header.extend(quote! { #field_ident, });
variant_body.extend(quote! {
#cratename::tokio::AsyncBorshSerialize::serialize(#field_ident, writer).await?;
})
}
}
variant_header = quote! { ( #variant_header )};
variant_idx_body.extend(quote!(
#name::#variant_ident(..) => #discriminant_value,
));
}
Fields::Unit => {
variant_idx_body.extend(quote!(
#name::#variant_ident => #discriminant_value,
));
}
}
fields_body.extend(quote!(
#name::#variant_ident #variant_header => {
#variant_body
}
))
}
Ok(quote! {
#[async_trait::async_trait]
impl #impl_generics #cratename::tokio::ser::AsyncBorshSerialize for #name #ty_generics #where_clause {
async fn serialize<W: #cratename::tokio::ser::AsyncWriter>(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::maybestd::io::Error> {
let variant_idx: u8 = match self {
#variant_idx_body
};
writer.write_all(&variant_idx.to_le_bytes()).await?;

match self {
#fields_body
}
Ok(())
}
}
})
}
19 changes: 19 additions & 0 deletions borsh-derive-internal/src/tokio/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#![recursion_limit = "128"]
// TODO: re-enable this lint when we bump msrv to 1.58
#![allow(clippy::uninlined_format_args)]

mod attribute_helpers;
mod enum_de;
mod enum_discriminant_map;
mod enum_ser;
mod struct_de;
mod struct_ser;
mod union_de;
mod union_ser;

pub use enum_de::enum_de;
pub use enum_ser::enum_ser;
pub use struct_de::struct_de;
pub use struct_ser::struct_ser;
pub use union_de::union_de;
pub use union_ser::union_ser;
85 changes: 85 additions & 0 deletions borsh-derive-internal/src/tokio/struct_de.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Fields, Ident, ItemStruct, WhereClause};

use crate::attribute_helpers::{contains_initialize_with, contains_skip};

pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut where_clause = where_clause.map_or_else(
|| WhereClause {
where_token: Default::default(),
predicates: Default::default(),
},
Clone::clone,
);
let init_method = contains_initialize_with(&input.attrs)?;
let return_value = match &input.fields {
Fields::Named(fields) => {
let mut body = TokenStream2::new();
for field in &fields.named {
let field_name = field.ident.as_ref().unwrap();
let delta = if contains_skip(&field.attrs) {
quote! {
#field_name: Default::default(),
}
} else {
let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::tokio::de::AsyncBorshDeserialize
})
.unwrap(),
);

quote! {
#field_name: #cratename::tokio::de::AsyncBorshDeserialize::deserialize_reader(reader).await?,
}
};
body.extend(delta);
}
quote! {
Self { #body }
}
}
Fields::Unnamed(fields) => {
let mut body = TokenStream2::new();
for _ in 0..fields.unnamed.len() {
let delta = quote! {
#cratename::tokio::de::AsyncBorshDeserialize::deserialize_reader(reader).await?,
};
body.extend(delta);
}
quote! {
Self( #body )
}
}
Fields::Unit => {
quote! {
Self {}
}
}
};
if let Some(method_ident) = init_method {
Ok(quote! {
#[async_trait::async_trait]
impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause {
async fn deserialize_reader<R: #cratename::tokio::de::AsyncReader>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let mut return_value = #return_value;
return_value.#method_ident();
Ok(return_value)
}
}
})
} else {
Ok(quote! {
#[async_trait::async_trait]
impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause {
async fn deserialize_reader<R: #cratename::tokio::de::AsyncReader>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
Ok(#return_value)
}
}
})
}
}
Loading