diff --git a/Cargo.lock b/Cargo.lock index 8f1fc1a3..cdac0c34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,15 +109,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "libc" -version = "0.2.138" +version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d7e329c562c5dfab7a46a2afabc8b987ab9a4834c9d1ca04dc54c1546cef8" +checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" [[package]] name = "opaque-debug" @@ -134,33 +134,33 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "serde" -version = "1.0.151" +version = "1.0.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fed41fc1a24994d044e6db6935e69511a1153b52c15eb42493b26fa87feba0" +checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.151" +version = "1.0.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "255abe9a125a985c05190d687b320c12f9b1f0b99445e608c21ba0782c719ad8" +checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" dependencies = [ "proc-macro2", "quote", @@ -169,9 +169,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" dependencies = [ "itoa", "ryu", @@ -180,27 +180,15 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.107" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +checksum = "bcc02725fd69ab9f26eab07fad303e2497fad6fb9eba4f96c4d1687bdf704ad9" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] -[[package]] -name = "synstructure" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "unicode-xid", -] - [[package]] name = "typenum" version = "1.16.0" @@ -209,15 +197,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicode-ident" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" - -[[package]] -name = "unicode-xid" -version = "0.2.4" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "version_check" @@ -250,5 +232,4 @@ dependencies = [ "proc-macro2", "quote", "syn", - "synstructure", ] diff --git a/fiat-constify/Cargo.toml b/fiat-constify/Cargo.toml index 2c2b46a3..c3af946d 100644 --- a/fiat-constify/Cargo.toml +++ b/fiat-constify/Cargo.toml @@ -18,4 +18,4 @@ rust-version = "1.56" [dependencies] proc-macro2 = "1" quote = "1" -syn = { version = "1", features = ["extra-traits", "full"] } +syn = { version = "2", features = ["extra-traits", "full"] } diff --git a/fiat-constify/src/main.rs b/fiat-constify/src/main.rs index c26a5c2c..b697190d 100644 --- a/fiat-constify/src/main.rs +++ b/fiat-constify/src/main.rs @@ -5,16 +5,16 @@ #![allow(clippy::single_match, clippy::new_without_default)] -use proc_macro2::Span; +use proc_macro2::{Literal, Span}; use quote::{quote, ToTokens}; use std::{collections::BTreeMap as Map, env, fs, ops::Deref}; use syn::{ punctuated::Punctuated, - token::{Bang, Brace, Bracket, Colon, Const, Eq, Let, Mut, Paren, Pound, RArrow, Semi}, + token::{Brace, Bracket, Colon, Const, Eq, Let, Mut, Not, Paren, Pound, RArrow, Semi}, AttrStyle, Attribute, Block, Expr, ExprAssign, ExprCall, ExprLit, ExprPath, ExprReference, - ExprRepeat, ExprTuple, FnArg, Ident, Item, ItemFn, ItemType, Lit, LitInt, Local, Pat, PatIdent, - PatTuple, PatType, Path, PathArguments, PathSegment, ReturnType, Stmt, Type, TypeArray, - TypePath, TypeReference, TypeTuple, UnOp, + ExprRepeat, ExprTuple, FnArg, Ident, Item, ItemFn, ItemType, Lit, LitInt, Local, LocalInit, + MacroDelimiter, Meta, MetaList, Pat, PatIdent, PatTuple, PatType, Path, PathArguments, + PathSegment, ReturnType, Stmt, Type, TypeArray, TypePath, TypeReference, TypeTuple, UnOp, }; fn main() -> Result<(), Box> { @@ -59,29 +59,33 @@ fn main() -> Result<(), Box> { /// Build a toplevel attribute with the given name and comma-separated values. fn build_attribute(name: &str, values: &[&str]) -> Attribute { - let span = Span::call_site(); let values = values .iter() .map(|value| build_path(value)) .collect::>(); + let path = build_path(name); + let tokens = quote! { #(#values),* }; + let delimiter = MacroDelimiter::Paren(Paren::default()); Attribute { - pound_token: Pound { spans: [span] }, - style: AttrStyle::Inner(Bang { spans: [span] }), - bracket_token: Bracket { span }, - path: build_path(name), - tokens: quote! { (#(#values),*) }, + pound_token: Pound::default(), + style: AttrStyle::Inner(Not::default()), + bracket_token: Bracket::default(), + meta: Meta::List(MetaList { + path, + delimiter, + tokens, + }), } } /// Parse a path from a double-colon-delimited string. fn build_path(path: &str) -> Path { - let span = Span::call_site(); let mut segments = Punctuated::new(); for segment in path.split("::") { segments.push(PathSegment { - ident: Ident::new(segment, span), + ident: Ident::new(segment, Span::call_site()), arguments: PathArguments::None, }); } @@ -103,10 +107,8 @@ fn get_ident_from_pat(pat: &Pat) -> Ident { /// Rewrite a fiat-crypto generated `fn` as a `const fn`, making the necessary /// transformations to the code in order for it to work in that context. fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) { - let span = Span::call_site(); - // Mark function as being `const fn`. - func.sig.constness = Some(Const { span }); + func.sig.constness = Some(Const::default()); // Transform mutable arguments into return values. let mut inputs = Punctuated::new(); @@ -142,7 +144,6 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) { /// values for outputs, removing mutable references, and adding a return /// value/tuple. fn rewrite_fn_body(statements: &[Stmt], outputs: &Outputs, registry: &TypeRegistry) -> Block { - let span = Span::call_site(); let mut stmts = Vec::new(); stmts.extend(outputs.to_let_bindings(registry).into_iter()); @@ -156,7 +157,7 @@ fn rewrite_fn_body(statements: &[Stmt], outputs: &Outputs, registry: &TypeRegist stmts.push(outputs.to_return_value()); Block { - brace_token: Brace { span }, + brace_token: Brace::default(), stmts, } } @@ -165,7 +166,7 @@ fn rewrite_fn_body(statements: &[Stmt], outputs: &Outputs, registry: &TypeRegist /// operations into value assignments. fn rewrite_fn_stmt(stmt: &mut Stmt) { match stmt { - Stmt::Semi(expr, _) => match expr { + Stmt::Expr(expr, Some(_)) => match expr { Expr::Assign(ExprAssign { left, .. }) => match *left.clone() { Expr::Unary(unary) => { // Remove deref since we're removing mutable references @@ -225,20 +226,22 @@ fn rewrite_fn_call(mut call: ExprCall) -> Local { // Overwrite call arguments with the ones that aren't mutable references call.args = args; - let span = Span::call_site(); - let pat = Pat::Tuple(PatTuple { attrs: Vec::new(), - paren_token: Paren { span }, + paren_token: Paren::default(), elems: output, }); Local { attrs: Vec::new(), - let_token: Let { span }, + let_token: Let::default(), pat, - init: Some((Eq { spans: [span] }, Box::new(Expr::Call(call)))), - semi_token: Semi { spans: [span] }, + init: Some(LocalInit { + eq_token: Eq::default(), + expr: Box::new(Expr::Call(call)), + diverge: None, + }), + semi_token: Semi::default(), } } @@ -299,28 +302,30 @@ impl Outputs { /// Generate `let mut outN: Ty = ` bindings at the start /// of the function. pub fn to_let_bindings(&self, registry: &TypeRegistry) -> Vec { - let span = Span::call_site(); - self.0 .iter() .map(|(ident, ty)| { Stmt::Local(Local { attrs: Vec::new(), - let_token: Let { span }, + let_token: Let::default(), pat: Pat::Type(PatType { attrs: Vec::new(), pat: Box::new(Pat::Ident(PatIdent { attrs: Vec::new(), by_ref: None, - mutability: Some(Mut { span }), + mutability: Some(Mut::default()), ident: ident.clone(), subpat: None, })), - colon_token: Colon { spans: [span] }, + colon_token: Colon::default(), ty: Box::new(ty.clone()), }), - init: Some((Eq { spans: [span] }, Box::new(default_for(ty, registry)))), - semi_token: Semi { spans: [span] }, + init: Some(LocalInit { + eq_token: Eq::default(), + expr: Box::new(default_for(ty, registry)), + diverge: None, + }), + semi_token: Semi::default(), }) }) .collect() @@ -328,10 +333,7 @@ impl Outputs { /// Finish annotating outputs, updating the provided `Signature`. pub fn to_return_type(&self) -> ReturnType { - let span = Span::call_site(); - let rarrow = RArrow { - spans: [span, span], - }; + let rarrow = RArrow::default(); let ret = match self.0.len() { 0 => panic!("expected at least one output"), @@ -344,7 +346,7 @@ impl Outputs { } Type::Tuple(TypeTuple { - paren_token: Paren { span }, + paren_token: Paren::default(), elems, }) } @@ -355,8 +357,6 @@ impl Outputs { /// Generate the return value for the statement as a tuple of the outputs. pub fn to_return_value(&self) -> Stmt { - let span = Span::call_site(); - let mut elems = self.0.keys().map(|ident| { let mut segments = Punctuated::new(); segments.push(PathSegment { @@ -377,31 +377,33 @@ impl Outputs { }); if elems.len() == 1 { - Stmt::Expr(elems.next().unwrap()) + Stmt::Expr(elems.next().unwrap(), None) } else { - Stmt::Expr(Expr::Tuple(ExprTuple { - attrs: Vec::new(), - paren_token: Paren { span }, - elems: elems.collect(), - })) + Stmt::Expr( + Expr::Tuple(ExprTuple { + attrs: Vec::new(), + paren_token: Paren::default(), + elems: elems.collect(), + }), + None, + ) } } } /// Get a default value for the given type. fn default_for(ty: &Type, registry: &TypeRegistry) -> Expr { - let span = Span::call_site(); let zero = Expr::Lit(ExprLit { attrs: Vec::new(), - lit: Lit::Int(LitInt::new("0", span)), + lit: Lit::Int(LitInt::from(Literal::u8_unsuffixed(0))), }); match ty { Type::Array(TypeArray { len, .. }) => Expr::Repeat(ExprRepeat { attrs: Vec::new(), - bracket_token: Bracket { span }, + bracket_token: Bracket::default(), expr: Box::new(zero), - semi_token: Semi { spans: [span] }, + semi_token: Semi::default(), len: Box::new(len.clone()), }), Type::Path(TypePath { path, .. }) => { diff --git a/zeroize/derive/Cargo.toml b/zeroize/derive/Cargo.toml index 9f8a5887..cadc4438 100644 --- a/zeroize/derive/Cargo.toml +++ b/zeroize/derive/Cargo.toml @@ -17,8 +17,7 @@ proc-macro = true [dependencies] proc-macro2 = "1" quote = "1" -syn = "1" -synstructure = "0.12.2" +syn = {version = "2", features = ["full", "extra-traits"]} [package.metadata.docs.rs] rustdoc-args = ["--document-private-items"] diff --git a/zeroize/derive/src/lib.rs b/zeroize/derive/src/lib.rs index baf69901..650dacb6 100644 --- a/zeroize/derive/src/lib.rs +++ b/zeroize/derive/src/lib.rs @@ -4,90 +4,118 @@ #![forbid(unsafe_code)] #![warn(rust_2018_idioms, trivial_casts, unused_qualifications)] -use proc_macro2::TokenStream; -use quote::quote; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, + parse_quote, punctuated::Punctuated, token::Comma, - Attribute, Lit, Meta, NestedMeta, Result, WherePredicate, + Attribute, Data, DeriveInput, Expr, ExprLit, Field, Fields, GenericParam, Lit, Meta, Result, + Variant, WherePredicate, }; -use synstructure::{decl_derive, AddBounds, BindStyle, BindingInfo, VariantInfo}; - -decl_derive!( - [Zeroize, attributes(zeroize)] => - - /// Derive the `Zeroize` trait. - /// - /// Supports the following attributes: - /// - /// On the item level: - /// - `#[zeroize(drop)]`: *deprecated* use `ZeroizeOnDrop` instead - /// - `#[zeroize(bound = "T: MyTrait")]`: this replaces any trait bounds - /// inferred by zeroize-derive - /// - /// On the field level: - /// - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()` - derive_zeroize -); - -decl_derive!( - [ZeroizeOnDrop, attributes(zeroize)] => - - /// Derive the `ZeroizeOnDrop` trait. - /// - /// Supports the following attributes: - /// - /// On the field level: - /// - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()` - derive_zeroize_on_drop -); /// Name of zeroize-related attributes const ZEROIZE_ATTR: &str = "zeroize"; -/// Custom derive for `Zeroize` -fn derive_zeroize(mut s: synstructure::Structure<'_>) -> TokenStream { - let attributes = ZeroizeAttrs::parse(&s); - - if let Some(bounds) = attributes.bound { - s.add_bounds(AddBounds::None); +/// Derive the `Zeroize` trait. +/// +/// Supports the following attributes: +/// +/// On the item level: +/// - `#[zeroize(drop)]`: *deprecated* use `ZeroizeOnDrop` instead +/// - `#[zeroize(bound = "T: MyTrait")]`: this replaces any trait bounds +/// inferred by zeroize-derive +/// +/// On the field level: +/// - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()` +#[proc_macro_derive(Zeroize, attributes(zeroize))] +pub fn derive_zeroize(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + derive_zeroize_impl(syn::parse_macro_input!(input as DeriveInput)).into() +} - for bound in bounds.0 { - s.add_where_predicate(bound); +fn derive_zeroize_impl(input: DeriveInput) -> TokenStream { + let attributes = ZeroizeAttrs::parse(&input); + + let extra_bounds = match attributes.bound { + Some(bounds) => bounds.0, + None => { + let mut out: Punctuated = Default::default(); + for param in &input.generics.params { + if let GenericParam::Type(type_param) = param { + let type_name = &type_param.ident; + out.push(parse_quote! { #type_name: Zeroize }) + } + } + out } - } + }; + + let mut generics = input.generics.clone(); + generics.make_where_clause().predicates.extend(extra_bounds); + + let ty_name = &input.ident; - // NOTE: These are split into named functions to simplify testing with - // synstructure's `test_derive!` macro. - if attributes.drop { - derive_zeroize_with_drop(s) + let (impl_gen, type_gen, where_) = generics.split_for_impl(); + + let drop_impl = if attributes.drop { + quote! { + #[doc(hidden)] + impl #impl_gen Drop for #ty_name #type_gen #where_ { + fn drop(&mut self) { + self.zeroize() + } + } + } } else { - derive_zeroize_without_drop(s) + quote! {} + }; + + let zeroizers = generate_fields(&input, quote! { zeroize }); + let zeroize_impl = quote! { + impl #impl_gen ::zeroize::Zeroize for #ty_name #type_gen #where_ { + fn zeroize(&mut self) { + #zeroizers + } + } + }; + + quote! { + #zeroize_impl + #drop_impl } } -/// Custom derive for `ZeroizeOnDrop` -fn derive_zeroize_on_drop(mut s: synstructure::Structure<'_>) -> TokenStream { - let zeroizers = generate_fields(&mut s, quote! { zeroize_or_on_drop }); +/// Derive the `ZeroizeOnDrop` trait. +/// +/// Supports the following attributes: +/// +/// On the field level: +/// - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()` +#[proc_macro_derive(ZeroizeOnDrop, attributes(zeroize))] +pub fn derive_zeroize_on_drop(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + derive_zeroize_on_drop_impl(syn::parse_macro_input!(input as DeriveInput)).into() +} + +fn derive_zeroize_on_drop_impl(input: DeriveInput) -> TokenStream { + let zeroizers = generate_fields(&input, quote! { zeroize_or_on_drop }); + + let (impl_gen, type_gen, where_) = input.generics.split_for_impl(); + let name = input.ident.clone(); - let drop_impl = s.add_bounds(AddBounds::None).gen_impl(quote! { - gen impl Drop for @Self { + let drop_impl = quote! { + impl #impl_gen Drop for #name #type_gen #where_ { fn drop(&mut self) { - use zeroize::__internal::AssertZeroize; - use zeroize::__internal::AssertZeroizeOnDrop; - match self { - #zeroizers - } + use ::zeroize::__internal::AssertZeroize; + use ::zeroize::__internal::AssertZeroizeOnDrop; + #zeroizers } } - }); - - let zeroize_on_drop_impl = impl_zeroize_on_drop(&s); + }; + let zeroize_on_drop_impl = impl_zeroize_on_drop(&input); quote! { #drop_impl - #zeroize_on_drop_impl } } @@ -112,40 +140,42 @@ impl Parse for Bounds { impl ZeroizeAttrs { /// Parse attributes from the incoming AST - fn parse(s: &synstructure::Structure<'_>) -> Self { + fn parse(input: &DeriveInput) -> Self { let mut result = Self::default(); - for attr in s.ast().attrs.iter() { + for attr in &input.attrs { result.parse_attr(attr, None, None); } - for v in s.variants().iter() { - // only process actual enum variants here, as we don't want to process struct attributes twice - if v.prefix.is_some() { - for attr in v.ast().attrs.iter() { - result.parse_attr(attr, Some(v), None); + + match &input.data { + syn::Data::Enum(enum_) => { + for variant in &enum_.variants { + for attr in &variant.attrs { + result.parse_attr(attr, Some(variant), None); + } + for field in &variant.fields { + for attr in &field.attrs { + result.parse_attr(attr, Some(variant), Some(field)); + } + } } } - for binding in v.bindings().iter() { - for attr in binding.ast().attrs.iter() { - result.parse_attr(attr, Some(v), Some(binding)); + syn::Data::Struct(struct_) => { + for field in &struct_.fields { + for attr in &field.attrs { + result.parse_attr(attr, None, Some(field)); + } } } + syn::Data::Union(union_) => panic!("Unsupported untagged union {:?}", union_), } result } /// Parse attribute and handle `#[zeroize(...)]` attributes - fn parse_attr( - &mut self, - attr: &Attribute, - variant: Option<&VariantInfo<'_>>, - binding: Option<&BindingInfo<'_>>, - ) { - let meta_list = match attr - .parse_meta() - .unwrap_or_else(|e| panic!("error parsing attribute: {:?} ({})", attr, e)) - { + fn parse_attr(&mut self, attr: &Attribute, variant: Option<&Variant>, binding: Option<&Field>) { + let meta_list = match &attr.meta { Meta::List(list) => list, _ => return, }; @@ -155,29 +185,23 @@ impl ZeroizeAttrs { return; } - for nested_meta in &meta_list.nested { - if let NestedMeta::Meta(meta) = nested_meta { - self.parse_meta(meta, variant, binding); - } else { - panic!("malformed #[zeroize] attribute: {:?}", nested_meta); - } + for meta in attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap_or_else(|e| panic!("error parsing attribute: {:?} ({})", attr, e)) + { + self.parse_meta(&meta, variant, binding); } } /// Parse `#[zeroize(...)]` attribute metadata (e.g. `drop`) - fn parse_meta( - &mut self, - meta: &Meta, - variant: Option<&VariantInfo<'_>>, - binding: Option<&BindingInfo<'_>>, - ) { + fn parse_meta(&mut self, meta: &Meta, variant: Option<&Variant>, binding: Option<&Field>) { if meta.path().is_ident("drop") { assert!(!self.drop, "duplicate #[zeroize] drop flags"); match (variant, binding) { (_variant, Some(_binding)) => { // structs don't have a variant prefix, and only structs have bindings outside of a variant - let item_kind = match variant.and_then(|variant| variant.prefix) { + let item_kind = match variant { Some(_) => "enum", None => "struct", }; @@ -203,7 +227,7 @@ impl ZeroizeAttrs { match (variant, binding) { (_variant, Some(_binding)) => { // structs don't have a variant prefix, and only structs have bindings outside of a variant - let item_kind = match variant.and_then(|variant| variant.prefix) { + let item_kind = match variant { Some(_) => "enum", None => "struct", }; @@ -221,7 +245,10 @@ impl ZeroizeAttrs { )), (None, None) => { if let Meta::NameValue(meta_name_value) = meta { - if let Lit::Str(lit) = &meta_name_value.lit { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit), .. + }) = &meta_name_value.value + { if lit.value().is_empty() { self.bound = Some(Bounds(Punctuated::new())); } else { @@ -253,277 +280,278 @@ impl ZeroizeAttrs { } } -fn generate_fields(s: &mut synstructure::Structure<'_>, method: TokenStream) -> TokenStream { - s.bind_with(|_| BindStyle::RefMut); +fn field_ident(n: usize, field: &Field) -> Ident { + if let Some(ref name) = field.ident { + name.clone() + } else { + format_ident!("__zeroize_field_{}", n) + } +} - s.filter_variants(|vi| { - let result = filter_skip(vi.ast().attrs, true); +fn generate_fields(input: &DeriveInput, method: TokenStream) -> TokenStream { + let input_id = &input.ident; + let fields: Vec<_> = match input.data { + Data::Enum(ref enum_) => enum_ + .variants + .iter() + .filter_map(|variant| { + if attr_skip(&variant.attrs) { + if variant.fields.iter().any(|field| attr_skip(&field.attrs)) { + panic!("duplicate #[zeroize] skip flags") + } + None + } else { + let variant_id = &variant.ident; + Some((quote! { #input_id :: #variant_id }, &variant.fields)) + } + }) + .collect(), + Data::Struct(ref struct_) => vec![(quote! { #input_id }, &struct_.fields)], + Data::Union(ref union_) => panic!("Cannot generate fields for untagged union {:?}", union_), + }; + + let arms = fields.into_iter().map(|(name, fields)| { + let method_field = fields.iter().enumerate().filter_map(|(n, field)| { + if attr_skip(&field.attrs) { + None + } else { + let name = field_ident(n, field); + Some(quote! { #name.#method() }) + } + }); + + let field_bindings = fields + .iter() + .enumerate() + .map(|(n, field)| field_ident(n, field)); + + let binding = match fields { + Fields::Named(_) => quote! { + #name { #(#field_bindings),* } + }, + Fields::Unnamed(_) => quote! { + #name ( #(#field_bindings),* ) + }, + Fields::Unit => quote! { + #name + }, + }; - // check for duplicate `#[zeroize(skip)]` attributes in nested variants - for field in vi.ast().fields { - filter_skip(&field.attrs, result); + quote! { + #[allow(unused_variables)] + #binding => { + #(#method_field);* + } } + }); - result - }) - .filter(|bi| filter_skip(&bi.ast().attrs, true)) - .each(|bi| quote! { #bi.#method(); }) + quote! { + match self { + #(#arms),* + _ => {} + } + } } -fn filter_skip(attrs: &[Attribute], start: bool) -> bool { - let mut result = start; - - for attr in attrs.iter().filter_map(|attr| attr.parse_meta().ok()) { +fn attr_skip(attrs: &[Attribute]) -> bool { + let mut result = false; + for attr in attrs.iter().map(|attr| &attr.meta) { if let Meta::List(list) = attr { if list.path.is_ident(ZEROIZE_ATTR) { - for nested in list.nested { - if let NestedMeta::Meta(Meta::Path(path)) = nested { + for meta in list + .parse_args_with(Punctuated::::parse_terminated) + .unwrap_or_else(|e| panic!("error parsing attribute: {:?} ({})", list, e)) + { + if let Meta::Path(path) = meta { if path.is_ident("skip") { - assert!(result, "duplicate #[zeroize] skip flags"); - result = false; + assert!(!result, "duplicate #[zeroize] skip flags"); + result = true; } } } } } } - result } -/// Custom derive for `Zeroize` (without `Drop`) -fn derive_zeroize_without_drop(mut s: synstructure::Structure<'_>) -> TokenStream { - let zeroizers = generate_fields(&mut s, quote! { zeroize }); - - s.bound_impl( - quote!(zeroize::Zeroize), - quote! { - fn zeroize(&mut self) { - match self { - #zeroizers - } - } - }, - ) -} - -/// Custom derive for `Zeroize` and `Drop` -fn derive_zeroize_with_drop(s: synstructure::Structure<'_>) -> TokenStream { - let drop_impl = s.gen_impl(quote! { - gen impl Drop for @Self { - fn drop(&mut self) { - self.zeroize(); - } - } - }); - - let zeroize_impl = derive_zeroize_without_drop(s); - +fn impl_zeroize_on_drop(input: &DeriveInput) -> TokenStream { + let name = input.ident.clone(); + let (impl_gen, type_gen, where_) = input.generics.split_for_impl(); quote! { - #zeroize_impl - #[doc(hidden)] - #drop_impl + impl #impl_gen ::zeroize::ZeroizeOnDrop for #name #type_gen #where_ {} } } -fn impl_zeroize_on_drop(s: &synstructure::Structure<'_>) -> TokenStream { - #[allow(unused_qualifications)] - s.bound_impl(quote!(zeroize::ZeroizeOnDrop), Option::::None) -} - #[cfg(test)] mod tests { use super::*; - use syn::parse_str; - use synstructure::{test_derive, Structure}; + + #[track_caller] + fn test_derive( + f: impl Fn(DeriveInput) -> TokenStream, + input: TokenStream, + expected_output: TokenStream, + ) { + let output = f(syn::parse2(input).unwrap()); + assert_eq!(format!("{output}"), format!("{expected_output}")); + } + + #[track_caller] + fn parse_zeroize_test(unparsed: &str) -> TokenStream { + derive_zeroize_impl(syn::parse_str(unparsed).expect("Failed to parse test input")) + } #[test] fn zeroize_without_drop() { - test_derive! { - derive_zeroize_without_drop { + test_derive( + derive_zeroize_impl, + quote! { struct Z { a: String, b: Vec, c: [u8; 3], } - } - expands to { - #[allow(non_upper_case_globals)] - #[doc(hidden)] - const _DERIVE_zeroize_Zeroize_FOR_Z: () = { - extern crate zeroize; - impl zeroize::Zeroize for Z { - fn zeroize(&mut self) { - match self { - Z { - a: ref mut __binding_0, - b: ref mut __binding_1, - c: ref mut __binding_2, - } => { - { __binding_0.zeroize(); } - { __binding_1.zeroize(); } - { __binding_2.zeroize(); } - } + }, + quote! { + impl ::zeroize::Zeroize for Z { + fn zeroize(&mut self) { + match self { + #[allow(unused_variables)] + Z { a, b, c } => { + a.zeroize(); + b.zeroize(); + c.zeroize() } + _ => {} } } - }; - } - no_build // tests the code compiles are in the `zeroize` crate - } + } + }, + ) } #[test] fn zeroize_with_drop() { - test_derive! { - derive_zeroize_with_drop { + test_derive( + derive_zeroize_impl, + quote! { + #[zeroize(drop)] struct Z { a: String, b: Vec, c: [u8; 3], } - } - expands to { - #[allow(non_upper_case_globals)] - #[doc(hidden)] - const _DERIVE_zeroize_Zeroize_FOR_Z: () = { - extern crate zeroize; - impl zeroize::Zeroize for Z { - fn zeroize(&mut self) { - match self { - Z { - a: ref mut __binding_0, - b: ref mut __binding_1, - c: ref mut __binding_2, - } => { - { __binding_0.zeroize(); } - { __binding_1.zeroize(); } - { __binding_2.zeroize(); } - } + }, + quote! { + impl ::zeroize::Zeroize for Z { + fn zeroize(&mut self) { + match self { + #[allow(unused_variables)] + Z { a, b, c } => { + a.zeroize(); + b.zeroize(); + c.zeroize() } + _ => {} } } - }; + } #[doc(hidden)] - #[allow(non_upper_case_globals)] - const _DERIVE_Drop_FOR_Z: () = { - impl Drop for Z { - fn drop(&mut self) { - self.zeroize(); - } + impl Drop for Z { + fn drop(&mut self) { + self.zeroize() } - }; - } - no_build // tests the code compiles are in the `zeroize` crate - } + } + }, + ) } #[test] fn zeroize_with_skip() { - test_derive! { - derive_zeroize_without_drop { + test_derive( + derive_zeroize_impl, + quote! { struct Z { a: String, b: Vec, #[zeroize(skip)] c: [u8; 3], } - } - expands to { - #[allow(non_upper_case_globals)] - #[doc(hidden)] - const _DERIVE_zeroize_Zeroize_FOR_Z: () = { - extern crate zeroize; - impl zeroize::Zeroize for Z { - fn zeroize(&mut self) { - match self { - Z { - a: ref mut __binding_0, - b: ref mut __binding_1, - .. - } => { - { __binding_0.zeroize(); } - { __binding_1.zeroize(); } - } + }, + quote! { + impl ::zeroize::Zeroize for Z { + fn zeroize(&mut self) { + match self { + #[allow(unused_variables)] + Z { a, b, c } => { + a.zeroize(); + b.zeroize() } + _ => {} } } - }; - } - no_build // tests the code compiles are in the `zeroize` crate - } + } + }, + ) } #[test] fn zeroize_with_bound() { - test_derive! { - derive_zeroize { + test_derive( + derive_zeroize_impl, + quote! { #[zeroize(bound = "T: MyTrait")] struct Z(T); - } - expands to { - #[allow(non_upper_case_globals)] - #[doc(hidden)] - const _DERIVE_zeroize_Zeroize_FOR_Z: () = { - extern crate zeroize; - impl zeroize::Zeroize for Z - where T: MyTrait - { - fn zeroize(&mut self) { - match self { - Z(ref mut __binding_0,) => { - { __binding_0.zeroize(); } - } + }, + quote! { + impl ::zeroize::Zeroize for Z where T: MyTrait { + fn zeroize(&mut self) { + match self { + #[allow(unused_variables)] + Z(__zeroize_field_0) => { + __zeroize_field_0.zeroize() } + _ => {} } } - }; - } - no_build // tests the code compiles are in the `zeroize` crate - } + } + }, + ) } #[test] fn zeroize_only_drop() { - test_derive! { - derive_zeroize_on_drop { + test_derive( + derive_zeroize_on_drop_impl, + quote! { struct Z { a: String, b: Vec, c: [u8; 3], } - } - expands to { - #[allow(non_upper_case_globals)] - const _DERIVE_Drop_FOR_Z: () = { - impl Drop for Z { - fn drop(&mut self) { - use zeroize::__internal::AssertZeroize; - use zeroize::__internal::AssertZeroizeOnDrop; - match self { - Z { - a: ref mut __binding_0, - b: ref mut __binding_1, - c: ref mut __binding_2, - } => { - { __binding_0.zeroize_or_on_drop(); } - { __binding_1.zeroize_or_on_drop(); } - { __binding_2.zeroize_or_on_drop(); } - } + }, + quote! { + impl Drop for Z { + fn drop(&mut self) { + use ::zeroize::__internal::AssertZeroize; + use ::zeroize::__internal::AssertZeroizeOnDrop; + match self { + #[allow(unused_variables)] + Z { a, b, c } => { + a.zeroize_or_on_drop(); + b.zeroize_or_on_drop(); + c.zeroize_or_on_drop() } + _ => {} } } - }; - #[allow(non_upper_case_globals)] + } #[doc(hidden)] - const _DERIVE_zeroize_ZeroizeOnDrop_FOR_Z: () = { - extern crate zeroize; - impl zeroize::ZeroizeOnDrop for Z {} - }; - } - no_build // tests the code compiles are in the `zeroize` crate - } + impl ::zeroize::ZeroizeOnDrop for Z {} + }, + ) } #[test] @@ -801,10 +829,4 @@ mod tests { struct Z(T); )); } - - fn parse_zeroize_test(unparsed: &str) -> TokenStream { - derive_zeroize(Structure::new( - &parse_str(unparsed).expect("Failed to parse test input"), - )) - } }