Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: ser/de enum discriminant #138

Merged
merged 2 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions borsh-derive-internal/src/enum_de.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use core::convert::TryFrom;

use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Fields, Ident, ItemEnum, WhereClause};

use crate::attribute_helpers::{contains_initialize_with, contains_skip};
use crate::{
attribute_helpers::{contains_initialize_with, contains_skip},
enum_discriminant_map::discriminant_map,
};

pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
Expand All @@ -18,9 +19,10 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
);
let init_method = contains_initialize_with(&input.attrs)?;
let mut variant_arms = TokenStream2::new();
for (variant_idx, variant) in input.variants.iter().enumerate() {
let variant_idx = u8::try_from(variant_idx).expect("up to 256 enum variants are supported");
let discriminants = discriminant_map(&input.variants);
for variant in input.variants.iter() {
let variant_ident = &variant.ident;
let discriminant = discriminants.get(variant_ident).unwrap();
let mut variant_header = TokenStream2::new();
match &variant.fields {
Fields::Named(fields) => {
Expand Down Expand Up @@ -69,7 +71,7 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
Fields::Unit => {}
}
variant_arms.extend(quote! {
#variant_idx => #name::#variant_ident #variant_header ,
if variant_tag == #discriminant { #name::#variant_ident #variant_header } else
});
}

Expand All @@ -92,13 +94,13 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
fn deserialize_variant<R: borsh::maybestd::io::Read>(
reader: &mut R,
variant_idx: u8,
variant_tag: u8,
) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let mut return_value = match variant_idx {
#variant_arms
_ => return Err(#cratename::maybestd::io::Error::new(
let mut return_value =
#variant_arms {
return Err(#cratename::maybestd::io::Error::new(
#cratename::maybestd::io::ErrorKind::InvalidInput,
#cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx),
#cratename::maybestd::format!("Unexpected variant tag: {:?}", variant_tag),
))
};
#init
Expand Down
24 changes: 24 additions & 0 deletions borsh-derive-internal/src/enum_discriminant_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::collections::HashMap;

use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, Variant};

/// Calculates the discriminant that will be assigned by the compiler.
/// See: https://doc.rust-lang.org/reference/items/enumerations.html#assigning-discriminant-values
pub fn discriminant_map(variants: &Punctuated<Variant, Comma>) -> HashMap<Ident, TokenStream> {
let mut map = HashMap::new();

let mut next_discriminant_if_not_specified = quote! {0};

for variant in variants {
let this_discriminant = variant.discriminant.clone().map_or_else(
|| quote! { #next_discriminant_if_not_specified },
|(_, e)| quote! { #e },
);
next_discriminant_if_not_specified = quote! { #this_discriminant + 1 };
map.insert(variant.ident.clone(), this_discriminant);
}

map
}
13 changes: 7 additions & 6 deletions borsh-derive-internal/src/enum_ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{Fields, Ident, ItemEnum, WhereClause};

use crate::attribute_helpers::contains_skip;
use crate::{attribute_helpers::contains_skip, enum_discriminant_map::discriminant_map};

pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
Expand All @@ -18,11 +18,12 @@ pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
);
let mut variant_idx_body = TokenStream2::new();
let mut fields_body = TokenStream2::new();
for (variant_idx, variant) in input.variants.iter().enumerate() {
let variant_idx = u8::try_from(variant_idx).expect("up to 256 enum variants are supported");
let discriminants = discriminant_map(&input.variants);
for variant in input.variants.iter() {
let variant_ident = &variant.ident;
let mut variant_header = TokenStream2::new();
let mut variant_body = TokenStream2::new();
let discriminant_value = discriminants.get(variant_ident).unwrap();
match &variant.fields {
Fields::Named(fields) => {
for field in &fields.named {
Expand All @@ -46,7 +47,7 @@ pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
}
variant_header = quote! { { #variant_header }};
variant_idx_body.extend(quote!(
#name::#variant_ident { .. } => #variant_idx,
#name::#variant_ident { .. } => #discriminant_value,
));
}
Fields::Unnamed(fields) => {
Expand Down Expand Up @@ -77,12 +78,12 @@ pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
}
variant_header = quote! { ( #variant_header )};
variant_idx_body.extend(quote!(
#name::#variant_ident(..) => #variant_idx,
#name::#variant_ident(..) => #discriminant_value,
));
}
Fields::Unit => {
variant_idx_body.extend(quote!(
#name::#variant_ident => #variant_idx,
#name::#variant_ident => #discriminant_value,
));
}
}
Expand Down
1 change: 1 addition & 0 deletions borsh-derive-internal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

mod attribute_helpers;
mod enum_de;
mod enum_discriminant_map;
mod enum_ser;
mod struct_de;
mod struct_ser;
Expand Down
2 changes: 1 addition & 1 deletion borsh/tests/test_de_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn test_invalid_enum_variant() {
let bytes = vec![123];
assert_eq!(
A::try_from_slice(&bytes).unwrap_err().to_string(),
"Unexpected variant index: 123"
"Unexpected variant tag: 123"
);
}

Expand Down
35 changes: 35 additions & 0 deletions borsh/tests/test_simple_structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,41 @@ struct F2<'b> {
aa: Vec<A<'b>>,
}

#[derive(BorshSerialize, BorshDeserialize, PartialEq, Eq, Clone, Copy, Debug)]
enum X {
A,
B = 20,
C,
D,
E = 10,
F,
}

#[test]
fn test_discriminant_serialization() {
let values = vec![X::A, X::B, X::C, X::D, X::E, X::F];
for value in values {
assert_eq!(value.try_to_vec().unwrap(), [value as u8]);
}
}

#[test]
fn test_discriminant_deserialization() {
let values = vec![X::A, X::B, X::C, X::D, X::E, X::F];
for value in values {
assert_eq!(
<X as BorshDeserialize>::try_from_slice(&[value as u8]).unwrap(),
value,
);
}
}

#[test]
#[should_panic = "Unexpected variant tag: 2"]
fn test_deserialize_invalid_discriminant() {
<X as BorshDeserialize>::try_from_slice(&[2]).unwrap();
}

#[test]
fn test_simple_struct() {
let mut map: HashMap<String, String> = HashMap::new();
Expand Down