Skip to content

Commit

Permalink
♻️ Storing snapshot struct name
Browse files Browse the repository at this point in the history
  • Loading branch information
Ikrk authored and lukacan committed Feb 28, 2024
1 parent a90a39e commit 55ab963
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions crates/client/src/fuzzer/snapshot_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn get_snapshot_structs_and_impls(
let mut impls = String::new();
let mut type_aliases = String::new();
let parse_result = syn::parse_file(code).map_err(|e| e.to_string())?;
let mut unique_ctxs: HashMap<GenericArgument, String> = HashMap::new();
let mut unique_ctxs: HashMap<GenericArgument, Ident> = HashMap::new();
for (ix, ctx) in ix_ctx_pairs {
let mut ctx_ident = None;
let ix_name = ix.to_string().to_upper_camel_case();
Expand All @@ -82,39 +82,40 @@ fn get_snapshot_structs_and_impls(

// If ctx is in the HashMap, we do not need to generate deserialization code again, we can only create a type alias
match unique_ctxs.get(ctx) {
Some(unique_ix) => {
let ix_snapshot_name = format_ident!("{}Snapshot", ix_name);
let base_ix_snapshot_name = format_ident!("{}Snapshot", unique_ix);
Some(base_ix_snapshot_name) => {
let snapshot_alias_name = format_ident!("{}Snapshot", ix_name);
let type_alias =
quote! {pub type #ix_snapshot_name<'info> = #base_ix_snapshot_name<'info>;};
quote! {pub type #snapshot_alias_name<'info> = #base_ix_snapshot_name<'info>;};
type_aliases = format!("{}{}", type_aliases, type_alias.into_token_stream());
}
None => {
// recursively find the context struct and create a new version with wrapped fields into Option
if let Some(ctx) = find_ctx_struct(&parse_result.items, &ctx_ident) {
let fields_parsed = if let Fields::Named(f) = ctx.fields.clone() {
if let Some(ctx_struct_item) = find_ctx_struct(&parse_result.items, &ctx_ident) {
let fields_parsed = if let Fields::Named(f) = ctx_struct_item.fields.clone() {
let field_deser: ParseResult<Vec<AccountField>> =
f.named.iter().map(parse_account_field).collect();
field_deser
} else {
Err(ParseError::new(
ctx.fields.span(),
ctx_struct_item.fields.span(),
"Context struct parse errror.",
))
}
.map_err(|e| e.to_string())?;

let ix_snapshot_name = format_ident!("{}Snapshot", ix_name);
let wrapped_struct =
create_snapshot_struct(&ix_name, ctx, &fields_parsed).unwrap();
let deser_code = deserialize_ctx_struct_anchor(&ix_name, &fields_parsed)
.map_err(|e| e.to_string())?;
// let deser_code = deserialize_ctx_struct(ctx).unwrap();
create_snapshot_struct(&ix_snapshot_name, ctx_struct_item, &fields_parsed)
.unwrap();
let deser_code =
deserialize_ctx_struct_anchor(&ix_snapshot_name, &fields_parsed)
.map_err(|e| e.to_string())?;
structs = format!("{}{}", structs, wrapped_struct.into_token_stream());
impls = format!("{}{}", impls, deser_code.into_token_stream());
unique_ctxs.insert(ctx.clone(), ix_snapshot_name);
} else {
return Err(format!("The Context struct {} was not found", ctx_ident));
}
unique_ctxs.insert(ctx.clone(), ix_name);
}
};
}
Expand Down Expand Up @@ -206,11 +207,10 @@ fn is_optional(parsed_field: &AccountField) -> bool {

/// Creates new Snapshot struct from the context struct. Removes Box<> types.
fn create_snapshot_struct(
ix_name: &str,
snapshot_name: &Ident,
orig_struct: &ItemStruct,
parsed_fields: &[AccountField],
) -> Result<TokenStream, Box<dyn Error>> {
let struct_name = format_ident!("{}Snapshot", ix_name);
let wrapped_fields = match orig_struct.fields.clone() {
Fields::Named(named) => {
let field_wrappers =
Expand Down Expand Up @@ -270,7 +270,7 @@ fn create_snapshot_struct(

// Generate the new struct with Option-wrapped fields
let generated_struct: syn::ItemStruct = parse_quote! {
pub struct #struct_name<'info> #wrapped_fields
pub struct #snapshot_name<'info> #wrapped_fields
};

Ok(generated_struct.to_token_stream())
Expand All @@ -295,10 +295,9 @@ fn extract_inner_type(field_type: &Type) -> Option<&Type> {

/// Generates code to deserialize the snapshot structs.
fn deserialize_ctx_struct_anchor(
snapshot_struct: &str,
snapshot_name: &Ident,
parsed_fields: &[AccountField],
) -> Result<TokenStream, Box<dyn Error>> {
let impl_name = format_ident!("{}Snapshot", snapshot_struct);
let names_deser_pairs: Result<Vec<(TokenStream, TokenStream)>, _> = parsed_fields
.iter()
.map(|parsed_f| match parsed_f {
Expand Down Expand Up @@ -331,7 +330,7 @@ fn deserialize_ctx_struct_anchor(
let (names, fields_deser): (Vec<_>, Vec<_>) = names_deser_pairs?.iter().cloned().unzip();

let generated_deser_impl: syn::Item = parse_quote! {
impl<'info> #impl_name<'info> {
impl<'info> #snapshot_name<'info> {
pub fn deserialize_option(
accounts: &'info mut [Option<AccountInfo<'info>>],
) -> core::result::Result<Self, FuzzingError> {
Expand Down

0 comments on commit 55ab963

Please sign in to comment.