Skip to content

Commit

Permalink
🎨 Modularized code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ikrk committed Feb 2, 2024
1 parent c84c78c commit c2ddac3
Showing 1 changed file with 99 additions and 68 deletions.
167 changes: 99 additions & 68 deletions crates/client/src/fuzzer/snapshot_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{error::Error, fs::File, io::Read};

use anchor_lang::anchor_syn::{AccountField, Ty};
use cargo_metadata::camino::Utf8PathBuf;
use proc_macro2::{Span, TokenStream};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::parse::{Error as ParseError, Result as ParseResult};
use syn::spanned::Spanned;
Expand Down Expand Up @@ -40,86 +40,102 @@ pub fn generate_snapshots_code(code_path: Vec<(String, Utf8PathBuf)>) -> Result<
.content
.ok_or("the content of program module is missing")?;

let mut ix_ctx_pairs = Vec::new();
for item in items {
// Iterate through items in program module and find functions with the Context<_> parameter. Save the function name and the Context's inner type.
if let syn::Item::Fn(func) = item {
let func_name = &func.sig.ident;
let first_param_type = if let Some(param) = func.sig.inputs.into_iter().next() {
let mut ty = None::<GenericArgument>;
if let syn::FnArg::Typed(t) = param {
if let syn::Type::Path(tp) = *t.ty.clone() {
if let Some(seg) = tp.path.segments.into_iter().next() {
if let PathArguments::AngleBracketed(arg) = seg.arguments {
ty = arg.args.first().cloned();
}
}
}
}
ty
} else {
None
};
let ix_ctx_pairs = get_ix_ctx_pairs(&items)?;

let first_param_type = first_param_type.ok_or(format!(
"The function {} does not have the Context parameter and is malformed.",
func_name
))?;
let (structs, impls) = get_snapshot_structs_and_impls(code, &ix_ctx_pairs)?;

ix_ctx_pairs.push((func_name.clone(), first_param_type));
}
}

// Find definition of each Context struct and create new struct with fields wrapped in Option<_>
let mut structs = String::new();
let mut desers = String::new();
let parse_result = syn::parse_file(code).map_err(|e| e.to_string())?;
for pair in ix_ctx_pairs {
let mut ty = None;
if let GenericArgument::Type(syn::Type::Path(tp)) = &pair.1 {
ty = tp.path.get_ident().cloned();
// TODO add support for types with fully qualified path such as ix::Initialize
}
let ty = ty.ok_or(format!("malformed parameters of {} instruction", pair.0))?;

// recursively find the context struct and create a new version with wrapped fields into Option
if let Some(ctx) = get_ctx_struct(&parse_result.items, &ty) {
let fields_parsed = if let Fields::Named(f) = ctx.fields.clone() {
let field_deser: ParseResult<Vec<AccountField>> =
f.named.iter().map(parse_account_field).collect();
field_deser
} else {
Err(ParseError::new(
ctx.fields.span(),
"Context struct parse errror.",
))
}
.map_err(|e| e.to_string())?;

let wrapped_struct = wrap_fields_in_option(ctx, &fields_parsed).unwrap();
let deser_code = deserialize_ctx_struct_anchor(ctx, &fields_parsed)
.map_err(|e| e.to_string())?;
// let deser_code = deserialize_ctx_struct(ctx).unwrap();
structs = format!("{}{}", structs, wrapped_struct.into_token_stream());
desers = format!("{}{}", desers, deser_code.into_token_stream());
} else {
return Err(format!("The Context struct {} was not found", ty));
}
}
let use_statements = quote! {
use trdelnik_client::anchor_lang::{prelude::*, self};
use trdelnik_client::anchor_lang::solana_program::instruction::AccountMeta;
use trdelnik_client::fuzzing::{get_account_infos_option, FuzzingError};
}
.into_token_stream();
Ok(format!("{}{}{}", use_statements, structs, desers))
Ok(format!("{}{}{}", use_statements, structs, impls))
});

code.into_iter().collect()
}

/// Creates new snapshot structs with fields wrapped in Option<_> if approriate and the
/// respective implementations with snapshot deserialization methods
fn get_snapshot_structs_and_impls(
code: &str,
ix_ctx_pairs: &[(Ident, GenericArgument)],
) -> Result<(String, String), String> {
let mut structs = String::new();
let mut impls = String::new();
let parse_result = syn::parse_file(code).map_err(|e| e.to_string())?;
for pair in ix_ctx_pairs {
let mut ty = None;
if let GenericArgument::Type(syn::Type::Path(tp)) = &pair.1 {
ty = tp.path.get_ident().cloned();
// TODO add support for types with fully qualified path such as ix::Initialize
}
let ty = ty.ok_or(format!("malformed parameters of {} instruction", pair.0))?;

// recursively find the context struct and create a new version with wrapped fields into Option
if let Some(ctx) = find_ctx_struct(&parse_result.items, &ty) {
let fields_parsed = if let Fields::Named(f) = ctx.fields.clone() {
let field_deser: ParseResult<Vec<AccountField>> =
f.named.iter().map(parse_account_field).collect();
field_deser
} else {
Err(ParseError::new(
ctx.fields.span(),
"Context struct parse errror.",
))
}
.map_err(|e| e.to_string())?;

let wrapped_struct = wrap_fields_in_option(ctx, &fields_parsed).unwrap();
let deser_code =
deserialize_ctx_struct_anchor(ctx, &fields_parsed).map_err(|e| e.to_string())?;
// let deser_code = deserialize_ctx_struct(ctx).unwrap();
structs = format!("{}{}", structs, wrapped_struct.into_token_stream());
impls = format!("{}{}", impls, deser_code.into_token_stream());
} else {
return Err(format!("The Context struct {} was not found", ty));
}
}

Ok((structs, impls))
}

/// Iterates through items and finds functions with the Context<_> parameter. Returns pairs with the function name and the Context's inner type.
fn get_ix_ctx_pairs(items: &[Item]) -> Result<Vec<(Ident, GenericArgument)>, String> {
let mut ix_ctx_pairs = Vec::new();
for item in items {
if let syn::Item::Fn(func) = item {
let func_name = &func.sig.ident;
let first_param_type = if let Some(param) = func.sig.inputs.iter().next() {
let mut ty = None::<GenericArgument>;
if let syn::FnArg::Typed(t) = param {
if let syn::Type::Path(tp) = *t.ty.clone() {
if let Some(seg) = tp.path.segments.iter().next() {
if let PathArguments::AngleBracketed(arg) = &seg.arguments {
ty = arg.args.first().cloned();
}
}
}
}
ty
} else {
None
};

let first_param_type = first_param_type.ok_or(format!(
"The function {} does not have the Context parameter and is malformed.",
func_name
))?;

ix_ctx_pairs.push((func_name.clone(), first_param_type));
}
}
Ok(ix_ctx_pairs)
}

/// Recursively find a struct with a given `name`
fn get_ctx_struct<'a>(items: &'a Vec<syn::Item>, name: &'a syn::Ident) -> Option<&'a ItemStruct> {
fn find_ctx_struct<'a>(items: &'a Vec<syn::Item>, name: &'a syn::Ident) -> Option<&'a ItemStruct> {
for item in items {
if let Item::Struct(struct_item) = item {
if struct_item.ident == *name {
Expand All @@ -132,7 +148,7 @@ fn get_ctx_struct<'a>(items: &'a Vec<syn::Item>, name: &'a syn::Ident) -> Option
for item in items {
if let Item::Mod(mod_item) = item {
if let Some((_, items)) = &mod_item.content {
let r = get_ctx_struct(items, name);
let r = find_ctx_struct(items, name);
if r.is_some() {
return r;
}
Expand All @@ -143,6 +159,9 @@ fn get_ctx_struct<'a>(items: &'a Vec<syn::Item>, name: &'a syn::Ident) -> Option
None
}

/// Determines if an Account should be wrapped into the `Option` type.
/// The function returns true if the account has the init or close constraints set
/// and is not already wrapped into the `Option` type.
fn is_optional(parsed_field: &AccountField) -> bool {
let is_optional = match parsed_field {
AccountField::Field(field) => field.is_optional,
Expand All @@ -156,6 +175,10 @@ fn is_optional(parsed_field: &AccountField) -> bool {
(constraints.init.is_some() || constraints.is_close()) && !is_optional
}

/// Determines if an Accout should be deserialized as optional.
/// The function returns true if the account has the init or close constraints set
/// or if it is explicitly optional (it was wrapped into the `Option` type already
/// in the definition of it's corresponding context structure).
fn deserialize_as_option(parsed_field: &AccountField) -> bool {
let is_optional = match parsed_field {
AccountField::Field(field) => field.is_optional,
Expand Down Expand Up @@ -211,6 +234,7 @@ fn wrap_fields_in_option(
Ok(generated_struct.to_token_stream())
}

/// Generates code to deserialize the snapshot structs.
fn deserialize_ctx_struct_anchor(
snapshot_struct: &ItemStruct,
parsed_fields: &[AccountField],
Expand Down Expand Up @@ -266,6 +290,7 @@ fn deserialize_ctx_struct_anchor(
Ok(generated_deser_impl.to_token_stream())
}

/// Get the identifier (name) of the passed sysvar type.
fn sysvar_to_ident(sysvar: &anchor_lang::anchor_syn::SysvarTy) -> String {
let str = match sysvar {
anchor_lang::anchor_syn::SysvarTy::Clock => "Clock",
Expand All @@ -282,6 +307,9 @@ fn sysvar_to_ident(sysvar: &anchor_lang::anchor_syn::SysvarTy) -> String {
str.into()
}

/// Converts passed account type to token streams. The function returns a pair of streams where the first
/// variable in the pair is the type itself and the second is a fully qualified function to deserialize
/// the given type.
pub fn ty_to_tokens(ty: &anchor_lang::anchor_syn::Ty) -> Option<(TokenStream, TokenStream)> {
let (return_type, deser_method) = match ty {
Ty::AccountInfo | Ty::UncheckedAccount => return None,
Expand Down Expand Up @@ -342,6 +370,7 @@ pub fn ty_to_tokens(ty: &anchor_lang::anchor_syn::Ty) -> Option<(TokenStream, To
Some((return_type, deser_method))
}

/// Generates the code necessary to deserialize an account
fn deserialize_account_tokens(
name: &syn::Ident,
is_optional: bool,
Expand Down Expand Up @@ -369,6 +398,7 @@ fn deserialize_account_tokens(
}
}

/// Generates the code used with raw accounts as AccountInfo
fn acc_info_tokens(name: &syn::Ident) -> TokenStream {
quote! {
let #name = accounts_iter
Expand All @@ -377,6 +407,7 @@ fn acc_info_tokens(name: &syn::Ident) -> TokenStream {
}
}

/// Checks if the program attribute is present
fn has_program_attribute(attrs: &Vec<Attribute>) -> bool {
for attr in attrs {
if attr.path.is_ident("program") {
Expand Down

0 comments on commit c2ddac3

Please sign in to comment.