From 7538af12d8c3d08ab4381bbda0b2978422e5e36b Mon Sep 17 00:00:00 2001 From: rzvxa <3788964+rzvxa@users.noreply.github.com> Date: Tue, 2 Jul 2024 10:18:45 +0000 Subject: [PATCH] feat(ast_codegen): add visit generator (#3954) ~~The generated code is only here for the sake of my own comparison (instead of manually keeping a backup of the old generated file). I would clean this up as soon as it is ready, submit some parts of it as the down stack, and stack the actual generated code on top of this. So please don't let the huge diff distract you, It won't have many conflicts since almost all of these are the generated visit code, which is completely contained to its own module(other than some minor renaming refactors).~~ The order of function definitions is a bit different, I've used a depth-first search, We can switch to a breadth-first one to align functions more closely to the original. --- Cargo.lock | 1 + crates/oxc_ast/src/ast/js.rs | 15 +- crates/oxc_ast/src/ast/ts.rs | 2 +- crates/oxc_ast_macros/src/lib.rs | 2 +- tasks/ast_codegen/Cargo.toml | 1 + tasks/ast_codegen/src/defs.rs | 4 +- tasks/ast_codegen/src/generators/ast_kind.rs | 3 +- tasks/ast_codegen/src/generators/mod.rs | 2 + tasks/ast_codegen/src/generators/visit.rs | 668 +++++++++++++++++++ tasks/ast_codegen/src/linker.rs | 16 +- tasks/ast_codegen/src/main.rs | 15 +- tasks/ast_codegen/src/schema.rs | 2 +- tasks/ast_codegen/src/util.rs | 154 +++++ 13 files changed, 874 insertions(+), 11 deletions(-) create mode 100644 tasks/ast_codegen/src/generators/visit.rs create mode 100644 tasks/ast_codegen/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index 3243a33ccb0a5..ac10ae4bed122 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1337,6 +1337,7 @@ dependencies = [ name = "oxc_ast_codegen" version = "0.0.0" dependencies = [ + "convert_case", "itertools 0.13.0", "lazy_static", "prettyplease", diff --git a/crates/oxc_ast/src/ast/js.rs b/crates/oxc_ast/src/ast/js.rs index 075b0578bdd00..1d4f95314dbce 100644 --- a/crates/oxc_ast/src/ast/js.rs +++ b/crates/oxc_ast/src/ast/js.rs @@ -78,6 +78,7 @@ pub enum Expression<'a> { ChainExpression(Box<'a, ChainExpression<'a>>) = 16, ClassExpression(Box<'a, Class<'a>>) = 17, ConditionalExpression(Box<'a, ConditionalExpression<'a>>) = 18, + #[visit_args(flags = None)] FunctionExpression(Box<'a, Function<'a>>) = 19, ImportExpression(Box<'a, ImportExpression<'a>>) = 20, LogicalExpression(Box<'a, LogicalExpression<'a>>) = 21, @@ -250,6 +251,7 @@ pub enum ArrayExpressionElement<'a> { /// Elision(Elision) = 65, // `Expression` variants added here by `inherit_variants!` macro + // TODO: support for attributes syntax here so we can use `#[visit_as(ExpressionArrayElement)]` @inherit Expression } } @@ -967,6 +969,7 @@ pub struct BlockStatement<'a> { #[cfg_attr(feature = "serialize", serde(untagged))] pub enum Declaration<'a> { VariableDeclaration(Box<'a, VariableDeclaration<'a>>) = 32, + #[visit_args(flags = None)] FunctionDeclaration(Box<'a, Function<'a>>) = 33, ClassDeclaration(Box<'a, Class<'a>>) = 34, UsingDeclaration(Box<'a, UsingDeclaration<'a>>) = 35, @@ -1291,6 +1294,7 @@ pub struct TryStatement<'a> { pub span: Span, pub block: Box<'a, BlockStatement<'a>>, pub handler: Option>>, + #[visit_as(FinallyClause)] pub finalizer: Option>>, } @@ -1432,7 +1436,7 @@ pub struct BindingRestElement<'a> { #[visited_node] #[scope( // TODO: `ScopeFlags::Function` is not correct if this is a `MethodDefinition` - flags(ScopeFlags::Function), + flags(flags.unwrap_or(ScopeFlags::empty()) | ScopeFlags::Function), strict_if(self.body.as_ref().is_some_and(|body| body.has_use_strict_directive())), )] #[derive(Debug)] @@ -1586,6 +1590,7 @@ pub struct Class<'a> { pub decorators: Vec<'a, Decorator<'a>>, #[scope(enter_before)] pub id: Option>, + #[visit_as(ClassHeritage)] pub super_class: Option>, pub body: Box<'a, ClassBody<'a>>, pub type_parameters: Option>>, @@ -1635,6 +1640,12 @@ pub struct MethodDefinition<'a> { pub span: Span, pub decorators: Vec<'a, Decorator<'a>>, pub key: PropertyKey<'a>, + #[visit_args(flags = Some(match self.kind { + MethodDefinitionKind::Get => ScopeFlags::GetAccessor, + MethodDefinitionKind::Set => ScopeFlags::SetAccessor, + MethodDefinitionKind::Constructor => ScopeFlags::Constructor, + MethodDefinitionKind::Method => ScopeFlags::empty(), + }))] pub value: Box<'a, Function<'a>>, // FunctionExpression pub kind: MethodDefinitionKind, pub computed: bool, @@ -1946,9 +1957,11 @@ inherit_variants! { #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(untagged))] pub enum ExportDefaultDeclarationKind<'a> { + #[visit_args(flags = None)] FunctionDeclaration(Box<'a, Function<'a>>) = 64, ClassDeclaration(Box<'a, Class<'a>>) = 65, + #[visit(ignore)] TSInterfaceDeclaration(Box<'a, TSInterfaceDeclaration<'a>>) = 66, // `Expression` variants added here by `inherit_variants!` macro diff --git a/crates/oxc_ast/src/ast/ts.rs b/crates/oxc_ast/src/ast/ts.rs index af9683a9981d8..d174cd5638050 100644 --- a/crates/oxc_ast/src/ast/ts.rs +++ b/crates/oxc_ast/src/ast/ts.rs @@ -802,7 +802,7 @@ pub enum TSTypePredicateName<'a> { #[visited_node] #[scope( flags(ScopeFlags::TsModuleBlock), - strict_if(self.body.as_ref().is_some_and(|body| body.is_strict())), + strict_if(self.body.as_ref().is_some_and(TSModuleDeclarationBody::is_strict)), )] #[derive(Debug)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] diff --git a/crates/oxc_ast_macros/src/lib.rs b/crates/oxc_ast_macros/src/lib.rs index 37154f9e53877..5f4f54d1be27f 100644 --- a/crates/oxc_ast_macros/src/lib.rs +++ b/crates/oxc_ast_macros/src/lib.rs @@ -22,7 +22,7 @@ pub fn visited_node(_args: TokenStream, input: TokenStream) -> TokenStream { /// Dummy derive macro for a non-existent trait `VisitedNode`. /// /// Does not generate any code, only purpose is to allow using `#[scope]` attr in the type def. -#[proc_macro_derive(VisitedNode, attributes(scope))] +#[proc_macro_derive(VisitedNode, attributes(scope, visit, visit_as, visit_args))] pub fn visited_node_derive(_item: TokenStream) -> TokenStream { TokenStream::new() } diff --git a/tasks/ast_codegen/Cargo.toml b/tasks/ast_codegen/Cargo.toml index 441d7db8d97ab..012c27454bd2c 100644 --- a/tasks/ast_codegen/Cargo.toml +++ b/tasks/ast_codegen/Cargo.toml @@ -30,3 +30,4 @@ serde = { workspace = true, features = ["derive"] } regex = { workspace = true } prettyplease = { workspace = true } lazy_static = { workspace = true } +convert_case = { workspace = true } diff --git a/tasks/ast_codegen/src/defs.rs b/tasks/ast_codegen/src/defs.rs index 1733d8e7503d5..ad7b5cb69706e 100644 --- a/tasks/ast_codegen/src/defs.rs +++ b/tasks/ast_codegen/src/defs.rs @@ -1,5 +1,5 @@ use super::{REnum, RStruct, RType}; -use crate::{schema::Inherit, TypeName}; +use crate::{schema::Inherit, util::TypeExt, TypeName}; use quote::ToTokens; use serde::Serialize; @@ -95,7 +95,7 @@ impl From<&Inherit> for EnumInheritDef { fn from(inherit: &Inherit) -> Self { match inherit { Inherit::Linked { super_, variants } => Self { - super_name: super_.into(), + super_name: super_.get_ident().as_ident().unwrap().to_string(), variants: variants.iter().map(Into::into).collect(), }, Inherit::Unlinked(_) => { diff --git a/tasks/ast_codegen/src/generators/ast_kind.rs b/tasks/ast_codegen/src/generators/ast_kind.rs index c0cea6da76478..e2696003ead73 100644 --- a/tasks/ast_codegen/src/generators/ast_kind.rs +++ b/tasks/ast_codegen/src/generators/ast_kind.rs @@ -8,7 +8,7 @@ use super::generated_header; pub struct AstKindGenerator; -const BLACK_LIST: [&str; 69] = [ +pub const BLACK_LIST: [&str; 69] = [ "Expression", "ObjectPropertyKind", "TemplateElement", @@ -102,6 +102,7 @@ impl Generator for AstKindGenerator { let have_kinds: Vec<(Ident, Type)> = ctx .ty_table .iter() + .filter(|it| it.borrow().visitable()) .filter_map(|maybe_kind| match &*maybe_kind.borrow() { kind @ (RType::Enum(_) | RType::Struct(_)) if kind.visitable() => { let ident = kind.ident().unwrap().clone(); diff --git a/tasks/ast_codegen/src/generators/mod.rs b/tasks/ast_codegen/src/generators/mod.rs index 01959fc2a0c87..d13a656ad7ce6 100644 --- a/tasks/ast_codegen/src/generators/mod.rs +++ b/tasks/ast_codegen/src/generators/mod.rs @@ -1,6 +1,7 @@ mod ast; mod ast_kind; mod impl_get_span; +mod visit; /// Inserts a newline in the `TokenStream`. #[allow(unused)] @@ -42,3 +43,4 @@ pub(crate) use insert; pub use ast::AstGenerator; pub use ast_kind::AstKindGenerator; pub use impl_get_span::ImplGetSpanGenerator; +pub use visit::VisitGenerator; diff --git a/tasks/ast_codegen/src/generators/visit.rs b/tasks/ast_codegen/src/generators/visit.rs new file mode 100644 index 0000000000000..a8e24e18f84d9 --- /dev/null +++ b/tasks/ast_codegen/src/generators/visit.rs @@ -0,0 +1,668 @@ +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + iter::Cloned, +}; + +use convert_case::{Case, Casing}; +use itertools::Itertools; +use proc_macro2::{TokenStream, TokenTree}; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parenthesized, + parse::{Parse, ParseStream}, + parse2, parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token::Paren, + Arm, Attribute, Expr, Field, GenericArgument, Ident, Meta, MetaNameValue, Path, PathArguments, + Token, Type, Variant, +}; + +use crate::{ + generators::{ast_kind::BLACK_LIST as KIND_BLACK_LIST, insert}, + schema::{Inherit, REnum, RStruct, RType}, + util::{StrExt, TokenStreamExt, TypeExt, TypeIdentResult}, + CodegenCtx, Generator, GeneratorOutput, Result, TypeRef, +}; + +use super::generated_header; + +pub struct VisitGenerator; + +impl Generator for VisitGenerator { + fn name(&self) -> &'static str { + "VisitGenerator" + } + + fn generate(&mut self, ctx: &CodegenCtx) -> GeneratorOutput { + let visit = (String::from("visit"), generate_visit(ctx)); + + GeneratorOutput::Many(HashMap::from_iter(vec![visit])) + } +} + +fn generate_visit(ctx: &CodegenCtx) -> TokenStream { + let header = generated_header!(); + // we evaluate it outside of quote to take advantage of expression evaluation + // otherwise the `\n\` wouldn't work! + let file_docs = insert! {"\ + //! Visitor Pattern\n\ + //!\n\ + //! See:\n\ + //! * [visitor pattern](https://rust-unofficial.github.io/patterns/patterns/behavioural/visitor.html)\n\ + //! * [rustc visitor](https://github.com/rust-lang/rust/blob/master/compiler/rustc_ast/src/visit.rs)\n\ + "}; + + let (visits, walks) = VisitBuilder::new(ctx).build(); + + quote! { + #header + #file_docs + insert!("#![allow(clippy::self_named_module_files, clippy::semicolon_if_nothing_returned, clippy::match_wildcard_for_single_variants)]"); + + endl!(); + + use oxc_allocator::Vec; + use oxc_syntax::scope::ScopeFlags; + + endl!(); + + use crate::{ast::*, ast_kind::AstKind}; + + endl!(); + + use walk::*; + + endl!(); + + /// Syntax tree traversal + pub trait Visit<'a>: Sized { + #[allow(unused_variables)] + fn enter_node(&mut self, kind: AstKind<'a>) {} + #[allow(unused_variables)] + fn leave_node(&mut self, kind: AstKind<'a>) {} + + endl!(); + + #[allow(unused_variables)] + fn enter_scope(&mut self, flags: ScopeFlags) {} + fn leave_scope(&mut self) {} + + endl!(); + + fn alloc(&self, t: &T) -> &'a T { + insert!("// SAFETY:"); + insert!("// This should be safe as long as `src` is an reference from the allocator."); + insert!("// But honestly, I'm not really sure if this is safe."); + #[allow(unsafe_code)] + unsafe { + std::mem::transmute(t) + } + } + + #(#visits)* + } + + endl!(); + + pub mod walk { + use super::*; + + #(#walks)* + + } + } +} + +struct VisitBuilder<'a> { + ctx: &'a CodegenCtx, + visits: Vec, + walks: Vec, + cache: HashMap>; 2]>, +} + +impl<'a> VisitBuilder<'a> { + fn new(ctx: &'a CodegenCtx) -> Self { + Self { ctx, visits: Vec::new(), walks: Vec::new(), cache: HashMap::new() } + } + + fn build(mut self) -> (/* visits */ Vec, /* walks */ Vec) { + let program = { + let types: Vec<&TypeRef> = + self.ctx.ty_table.iter().filter(|it| it.borrow().visitable()).collect_vec(); + TypeRef::clone( + types + .iter() + .find(|it| it.borrow().ident().is_some_and(|ident| ident == "Program")) + .expect("Couldn't find the `Program` type!"), + ) + }; + + self.get_visitor(&program, false, None); + (self.visits, self.walks) + } + + fn get_visitor( + &mut self, + ty: &TypeRef, + collection: bool, + visit_as: Option<&Ident>, + ) -> Cow<'a, Ident> { + let cache_ix = usize::from(collection); + let (ident, as_type) = { + let ty = ty.borrow(); + debug_assert!(ty.visitable(), "{ty:?}"); + + let ident = ty.ident().unwrap(); + let as_type = ty.as_type().unwrap(); + + let ident = visit_as.unwrap_or(ident); + + (ident.clone(), if collection { parse_quote!(Vec<'a, #as_type>) } else { as_type }) + }; + + // is it already generated? + if let Some(cached) = self.cache.get(&ident) { + if let Some(cached) = &cached[cache_ix] { + return Cow::clone(cached); + } + } + + let ident_snake = { + let it = ident.to_string().to_case(Case::Snake); + let it = if collection { + // edge case for `Vec` to avoid conflicts with `FormalParameters` + // which both would generate the same name: `visit_formal_parameters`. + // and edge case for `Vec` to avoid conflicts with + // `TSImportAttributes` which both would generate the same name: `visit_formal_parameters`. + if matches!(it.as_str(), "formal_parameter" | "ts_import_attribute") { + let mut it = it; + it.push_str("_list"); + it + } else { + it.to_plural() + } + } else { + it + }; + format_ident!("{it}") + }; + + let as_param_type = quote!(&#as_type); + let (extra_params, extra_args) = if ident == "Function" { + (quote!(, flags: Option,), quote!(, flags)) + } else { + (TokenStream::default(), TokenStream::default()) + }; + + let visit_name = { + let visit_name = format_ident!("visit_{}", ident_snake); + if !self.cache.contains_key(&ident) { + debug_assert!(self.cache.insert(ident.clone(), [None, None]).is_none()); + } + let cached = self.cache.get_mut(&ident).unwrap(); + assert!(cached[cache_ix].replace(Cow::Owned(visit_name)).is_none()); + Cow::clone(cached[cache_ix].as_ref().unwrap()) + }; + + let walk_name = format_ident!("walk_{}", ident_snake); + + self.visits.push(quote! { + endl!(); + #[inline] + fn #visit_name (&mut self, it: #as_param_type #extra_params) { + #walk_name(self, it #extra_args); + } + }); + + // We push an empty walk first, because we evaluate - and generate - each walk as we go, + // This would let us to maintain the order of first visit. + let this_walker = self.walks.len(); + self.walks.push(TokenStream::default()); + + let walk_body = if collection { + let singular_visit = self.get_visitor(ty, false, None); + quote! { + for el in it { + visitor.#singular_visit(el); + } + } + } else { + match &*ty.borrow() { + RType::Enum(enum_) => self.generate_enum_walk(enum_, visit_as), + RType::Struct(struct_) => self.generate_struct_walk(struct_, visit_as), + _ => panic!(), + } + }; + + // replace the placeholder walker with the actual one! + self.walks[this_walker] = quote! { + endl!(); + pub fn #walk_name <'a, V: Visit<'a>>(visitor: &mut V, it: #as_param_type #extra_params) { + #walk_body + } + }; + + visit_name + } + + fn generate_enum_walk(&mut self, enum_: &REnum, visit_as: Option<&Ident>) -> TokenStream { + let ident = enum_.ident(); + let mut non_exhaustive = false; + let variants_matches = enum_ + .item + .variants + .iter() + .filter(|it| !it.attrs.iter().any(|a| a.path().is_ident("inherit"))) + .filter(|it| { + if it.attrs.iter().any(|a| { + a.path().is_ident("visit") + && a.meta + .require_list() + .unwrap() + .parse_args::() + .unwrap() + .is_ident("ignore") + }) { + // We are ignoring some variants so the match is no longer exhaustive. + non_exhaustive = true; + false + } else { + true + } + }) + .filter_map(|it| { + let typ = it + .fields + .iter() + .exactly_one() + .map(|f| &f.ty) + .map_err(|_| "We only support visited enum nodes with exactly one field!") + .unwrap(); + let variant_name = &it.ident; + let typ = self.ctx.find(&typ.get_ident().inner_ident().to_string())?; + let borrowed = typ.borrow(); + let visitable = borrowed.visitable(); + if visitable { + let visit = self.get_visitor(&typ, false, None); + let (args_def, args) = it + .attrs + .iter() + .find(|it| it.path().is_ident("visit_args")) + .map(|it| it.parse_args_with(VisitArgs::parse)) + .map(|it| { + it.into_iter() + .flatten() + .fold((Vec::new(), Vec::new()), Self::visit_args_fold) + }) + .unwrap_or_default(); + let body = quote!(visitor.#visit(it #(#args)*)); + let body = if args_def.is_empty() { + body + } else { + // if we have args wrap the result in a block to prevent ident clashes. + quote! {{ + #(#args_def)* + #body + }} + }; + Some(quote!(#ident::#variant_name(it) => #body)) + } else { + None + } + }) + .collect_vec(); + + let inherit_matches = enum_.meta.inherits.iter().filter_map(|it| { + let Inherit::Linked { super_, .. } = it else { panic!("Unresolved inheritance!") }; + let type_name = super_.get_ident().as_ident().unwrap().to_string(); + let typ = self.ctx.find(&type_name)?; + if typ.borrow().visitable() { + let snake_name = type_name.to_case(Case::Snake); + let match_macro = format_ident!("match_{snake_name}"); + let match_macro = quote!(#match_macro!(#ident)); + // HACK: edge case till we get attributes to work with inheritance. + let visit_as = if ident == "ArrayExpressionElement" + && super_.get_ident().inner_ident() == "Expression" + { + Some(format_ident!("ExpressionArrayElement")) + } else { + None + }; + let to_child = format_ident!("to_{snake_name}"); + let visit = self.get_visitor(&typ, false, visit_as.as_ref()); + Some(quote!(#match_macro => visitor.#visit(it.#to_child()))) + } else { + None + } + }); + + let matches = variants_matches.into_iter().chain(inherit_matches).collect_vec(); + + let with_node_events = |tk| { + let ident = visit_as.unwrap_or(ident); + if KIND_BLACK_LIST.contains(&ident.to_string().as_str()) { + tk + } else { + quote! { + let kind = AstKind::#ident(visitor.alloc(it)); + visitor.enter_node(kind); + #tk + visitor.leave_node(kind); + } + } + }; + let non_exhaustive = if non_exhaustive { Some(quote!(,_ => {})) } else { None }; + with_node_events(quote!(match it { #(#matches),* #non_exhaustive })) + } + + fn generate_struct_walk(&mut self, struct_: &RStruct, visit_as: Option<&Ident>) -> TokenStream { + let ident = visit_as.unwrap_or_else(|| struct_.ident()); + let scope_attr = struct_.item.attrs.iter().find(|it| it.path().is_ident("scope")); + let (scope_enter, scope_leave) = scope_attr + .map(parse_as_scope) + .transpose() + .unwrap() + .map_or_else(Default::default, |scope_args| { + let cond = scope_args.r#if.map(|cond| { + let cond = cond.to_token_stream().replace_ident("self", &format_ident!("it")); + quote!(let scope_events_cond = #cond;) + }); + let maybe_conditional = |tk: TokenStream| { + if cond.is_some() { + quote! { + if scope_events_cond { + #tk + } + } + } else { + tk + } + }; + let flags = scope_args + .flags + .map_or_else(|| quote!(ScopeFlags::empty()), |it| it.to_token_stream()); + let args = if let Some(strict_if) = scope_args.strict_if { + let strict_if = + strict_if.to_token_stream().replace_ident("self", &format_ident!("it")); + quote! {{ + let mut flags = #flags; + if #strict_if { + flags |= ScopeFlags::StrictMode; + } + flags + }} + } else { + flags + }; + let mut enter = cond.as_ref().into_token_stream(); + enter.extend(maybe_conditional(quote!(visitor.enter_scope(#args);))); + let leave = maybe_conditional(quote!(visitor.leave_scope();)); + (Some(enter), Some(leave)) + }); + let mut entered_scope = false; + let fields_visits: Vec = struct_ + .item + .fields + .iter() + .filter_map(|it| { + let (typ, typ_wrapper) = self.analyze_type(&it.ty)?; + let visit_as: Option = + it.attrs.iter().find(|it| it.path().is_ident("visit_as")).map(|it| { + match &it.meta { + Meta::List(meta) => { + parse2(meta.tokens.clone()).expect("wrong `visit_as` input!") + } + _ => panic!("wrong use of `visit_as`!"), + } + }); + // TODO: make sure it is `#[scope(enter_before)]` + let have_enter_scope = it.attrs.iter().any(|it| it.path().is_ident("scope")); + let args = it.attrs.iter().find(|it| it.meta.path().is_ident("visit_args")); + let (args_def, args) = args + .map(|it| it.parse_args_with(VisitArgs::parse)) + .map(|it| { + it.into_iter() + .flatten() + .fold((Vec::new(), Vec::new()), Self::visit_args_fold) + }) + .unwrap_or_default(); + let visit = self.get_visitor( + &typ, + matches!( + typ_wrapper, + TypeWrapper::Vec | TypeWrapper::VecBox | TypeWrapper::OptVec + ), + visit_as.as_ref(), + ); + let name = it.ident.as_ref().expect("expected named fields!"); + let mut result = match typ_wrapper { + TypeWrapper::Opt | TypeWrapper::OptBox | TypeWrapper::OptVec => quote! { + if let Some(ref #name) = it.#name { + visitor.#visit(#name #(#args)*); + } + }, + TypeWrapper::VecOpt => quote! { + for #name in (&it.#name).into_iter().flatten() { + visitor.#visit(#name #(#args)*); + } + }, + _ => quote! { + visitor.#visit(&it.#name #(#args)*); + }, + }; + if have_enter_scope { + assert!(!entered_scope); + result = quote! { + #scope_enter + #result + }; + entered_scope = true; + } + + if args_def.is_empty() { + Some(result) + } else { + // if we have args wrap the result in a block to prevent ident clashes. + Some(quote! {{ + #(#args_def)* + #result + }}) + } + }) + .collect(); + + let body = if KIND_BLACK_LIST.contains(&ident.to_string().as_str()) { + let unused = + if fields_visits.is_empty() { Some(quote!(let _ = (visitor, it);)) } else { None }; + quote! { + insert!("// NOTE: AstKind doesn't exists!"); + #(#fields_visits)* + #unused + } + } else { + quote! { + let kind = AstKind::#ident(visitor.alloc(it)); + visitor.enter_node(kind); + #(#fields_visits)* + visitor.leave_node(kind); + } + }; + + match (scope_enter, scope_leave, entered_scope) { + (_, Some(leave), true) => quote! { + #body + #leave + }, + (Some(enter), Some(leave), false) => quote! { + #enter + #body + #leave + }, + _ => body, + } + } + + fn analyze_type(&self, ty: &Type) -> Option<(TypeRef, TypeWrapper)> { + fn analyze<'a>(res: &'a TypeIdentResult) -> Option<(&'a Ident, TypeWrapper)> { + let mut wrapper = TypeWrapper::None; + let ident = match res { + TypeIdentResult::Ident(inner) => inner, + TypeIdentResult::Box(inner) => { + wrapper = TypeWrapper::Box; + let (inner, inner_kind) = analyze(inner)?; + assert!(inner_kind == TypeWrapper::None,); + inner + } + TypeIdentResult::Vec(inner) => { + wrapper = TypeWrapper::Vec; + let (inner, inner_kind) = analyze(inner)?; + if inner_kind == TypeWrapper::Opt { + wrapper = TypeWrapper::VecOpt; + } else if inner_kind != TypeWrapper::None { + panic!(); + } + inner + } + TypeIdentResult::Option(inner) => { + wrapper = TypeWrapper::Opt; + let (inner, inner_kind) = analyze(inner)?; + if inner_kind == TypeWrapper::Vec { + wrapper = TypeWrapper::OptVec; + } else if inner_kind == TypeWrapper::Box { + wrapper = TypeWrapper::OptBox; + } else if inner_kind != TypeWrapper::None { + panic!(); + } + inner + } + TypeIdentResult::Reference(_) => return None, + }; + Some((ident, wrapper)) + } + let type_ident = ty.get_ident(); + let (type_ident, wrapper) = analyze(&type_ident)?; + + let type_ref = self.ctx.find(&type_ident.to_string())?; + if type_ref.borrow().visitable() { + Some((type_ref, wrapper)) + } else { + None + } + } + + fn visit_args_fold( + mut accumulator: (Vec, Vec), + arg: VisitArg, + ) -> (Vec, Vec) { + let VisitArg { ident: id, value: val } = arg; + let val = val.to_token_stream().replace_ident("self", &format_ident!("it")); + accumulator.0.push(quote!(let #id = #val;)); + accumulator.1.push(quote!(, #id)); + accumulator + } +} + +#[derive(PartialEq)] +enum TypeWrapper { + None, + Box, + Vec, + Opt, + VecBox, + VecOpt, + OptBox, + OptVec, +} + +#[derive(Debug)] +struct VisitArgs(Punctuated); + +impl IntoIterator for VisitArgs { + type Item = VisitArg; + type IntoIter = syn::punctuated::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +#[derive(Debug)] +struct VisitArg { + ident: Ident, + value: Expr, +} + +#[derive(Debug, Default)] +struct ScopeArgs { + r#if: Option, + flags: Option, + strict_if: Option, +} + +impl Parse for VisitArgs { + fn parse(input: ParseStream) -> std::result::Result { + input.parse_terminated(VisitArg::parse, Token![,]).map(Self) + } +} + +impl Parse for VisitArg { + fn parse(input: ParseStream) -> std::result::Result { + let nv: MetaNameValue = input.parse()?; + Ok(Self { + ident: nv.path.get_ident().map_or_else( + || Err(syn::Error::new(nv.span(), "Invalid `visit_args` input!")), + |it| Ok(it.clone()), + )?, + value: nv.value, + }) + } +} + +impl Parse for ScopeArgs { + fn parse(input: ParseStream) -> std::result::Result { + fn parse(input: ParseStream) -> std::result::Result<(String, Expr), syn::Error> { + let ident = if let Ok(ident) = input.parse::() { + ident.to_string() + } else if input.parse::().is_ok() { + String::from("if") + } else { + return Err(syn::Error::new(input.span(), "Invalid `#[scope]` input.")); + }; + let content; + parenthesized!(content in input); + Ok((ident, content.parse()?)) + } + + let parsed = input.parse_terminated(parse, Token![,])?; + Ok(parsed.into_iter().fold(Self::default(), |mut acc, (ident, expr)| { + match ident.as_str() { + "if" => acc.r#if = Some(expr), + "flags" => acc.flags = Some(expr), + "strict_if" => acc.strict_if = Some(expr), + _ => {} + } + acc + })) + } +} + +fn parse_as_visit_args(attr: &Attribute) -> Vec<(Ident, TokenStream)> { + debug_assert!(attr.path().is_ident("visit_args")); + let mut result = Vec::new(); + let args: MetaNameValue = attr.parse_args().expect("Invalid `visit_args` input!"); + let ident = args.path.get_ident().unwrap().clone(); + let value = args.value.to_token_stream(); + result.push((ident, value)); + result +} + +fn parse_as_scope(attr: &Attribute) -> std::result::Result { + debug_assert!(attr.path().is_ident("scope")); + if matches!(attr.meta, Meta::Path(_)) { + // empty! + Ok(ScopeArgs::default()) + } else { + attr.parse_args_with(ScopeArgs::parse) + } +} diff --git a/tasks/ast_codegen/src/linker.rs b/tasks/ast_codegen/src/linker.rs index 41e46391f6950..6d78f411a96a1 100644 --- a/tasks/ast_codegen/src/linker.rs +++ b/tasks/ast_codegen/src/linker.rs @@ -1,5 +1,7 @@ use std::collections::VecDeque; +use syn::parse_quote; + use super::{CodegenCtx, Cow, Inherit, Itertools, RType, Result}; pub trait Linker<'a> { @@ -75,19 +77,27 @@ pub fn linker(ty: &mut RType, ctx: &CodegenCtx) -> Result { .map(|it| match it { Inherit::Unlinked(ref sup) => { let linkee = ctx.find(&Cow::Owned(sup.to_string())).unwrap(); - let variants = match &*linkee.borrow() { + let linkee = linkee.borrow(); + let inherit_value = format!(r#""{}""#, linkee.ident().unwrap()); + let variants = match &*linkee { RType::Enum(enum_) => { if enum_.meta.inherits.unresolved() { return Err(it); } - enum_.item.variants.clone() + enum_.item.variants.clone().into_iter().map(|mut v| { + v.attrs = vec![parse_quote!(#[inherit = #inherit_value])]; + v + }) } _ => { panic!("invalid inheritance, you can only inherit from enums and in enums.") } }; ty.item.variants.extend(variants.clone()); - Ok(Inherit::Linked { super_: sup.clone(), variants }) + Ok(Inherit::Linked { + super_: linkee.as_type().unwrap(), + variants: variants.collect(), + }) } Inherit::Linked { .. } => Ok(it), }) diff --git a/tasks/ast_codegen/src/main.rs b/tasks/ast_codegen/src/main.rs index 3d1ad02976edb..475e221ef2ad3 100644 --- a/tasks/ast_codegen/src/main.rs +++ b/tasks/ast_codegen/src/main.rs @@ -5,6 +5,7 @@ mod fmt; mod generators; mod linker; mod schema; +mod util; use std::{ borrow::Cow, @@ -22,7 +23,7 @@ use proc_macro2::TokenStream; use syn::parse_file; use defs::TypeDef; -use generators::{AstGenerator, AstKindGenerator}; +use generators::{AstGenerator, AstKindGenerator, VisitGenerator}; use linker::{linker, Linker}; use schema::{Inherit, Module, REnum, RStruct, RType, Schema}; @@ -191,6 +192,7 @@ fn main() -> std::result::Result<(), Box> { .with(AstGenerator) .with(AstKindGenerator) .with(ImplGetSpanGenerator) + .with(VisitGenerator) .generate()?; let output_dir = output_dir()?; @@ -218,6 +220,17 @@ fn main() -> std::result::Result<(), Box> { file.write_all(span_content.as_bytes())?; } + { + // write `visit.rs` file + let output = outputs[VisitGenerator.name()].as_many(); + let span_content = pprint(&output["visit"]); + + let path = format!("{output_dir}/visit.rs"); + let mut file = fs::File::create(path)?; + + file.write_all(span_content.as_bytes())?; + } + cargo_fmt(".")?; // let schema = serde_json::to_string_pretty(&schema).map_err(|e| e.to_string())?; diff --git a/tasks/ast_codegen/src/schema.rs b/tasks/ast_codegen/src/schema.rs index c478004600216..a2c3019075050 100644 --- a/tasks/ast_codegen/src/schema.rs +++ b/tasks/ast_codegen/src/schema.rs @@ -27,7 +27,7 @@ pub struct Definitions { #[derive(Debug, Clone)] pub enum Inherit { Unlinked(String), - Linked { super_: String, variants: Punctuated }, + Linked { super_: Type, variants: Punctuated }, } impl From for Inherit { diff --git a/tasks/ast_codegen/src/util.rs b/tasks/ast_codegen/src/util.rs new file mode 100644 index 0000000000000..7f94fb5908698 --- /dev/null +++ b/tasks/ast_codegen/src/util.rs @@ -0,0 +1,154 @@ +use itertools::Itertools; +use proc_macro2::{Group, TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use syn::{GenericArgument, Ident, PathArguments, Type, TypePath}; + +pub trait TokenStreamExt { + fn replace_ident(self, needle: &str, replace: &Ident) -> TokenStream; +} + +pub trait TypeExt { + fn get_ident(&self) -> TypeIdentResult; +} + +pub trait StrExt: AsRef { + /// Dead simple, just adds either `s` or `es` based on the last character. + /// doesn't handle things like `sh`, `x`, `z`, etc. It also creates wrong results when the word + /// ends with `y` but there is a preceding vowl similar to `toys`, + /// It WILL output the WRONG result `toies`! + /// As an edge case would output `children` for the input `child`. + fn to_plural(self) -> String; +} + +#[derive(Debug)] +pub enum TypeIdentResult<'a> { + Ident(&'a Ident), + Vec(Box>), + Box(Box>), + Option(Box>), + Reference(Box>), +} + +impl<'a> TypeIdentResult<'a> { + fn boxed(inner: Self) -> Self { + Self::Box(Box::new(inner)) + } + + fn vec(inner: Self) -> Self { + Self::Vec(Box::new(inner)) + } + + fn option(inner: Self) -> Self { + Self::Option(Box::new(inner)) + } + + fn reference(inner: Self) -> Self { + Self::Reference(Box::new(inner)) + } + + pub fn inner_ident(&self) -> &'a Ident { + match self { + Self::Ident(it) => it, + Self::Vec(it) | Self::Box(it) | Self::Option(it) | Self::Reference(it) => { + it.inner_ident() + } + } + } + + pub fn as_ident(&self) -> Option<&'a Ident> { + if let Self::Ident(it) = self { + Some(it) + } else { + None + } + } +} + +impl TypeExt for Type { + fn get_ident(&self) -> TypeIdentResult { + match self { + Type::Path(TypePath { path, .. }) => { + let seg1 = path.segments.first().unwrap(); + match &seg1.arguments { + PathArguments::None => TypeIdentResult::Ident(&seg1.ident), + PathArguments::AngleBracketed(it) => { + let args = &it.args.iter().collect_vec(); + assert!(args.len() < 3, "Max path arguments here is 2, eg `Box<'a, Adt>`"); + if let Some(second) = args.get(1) { + let GenericArgument::Type(second) = second else { panic!() }; + let inner = second.get_ident(); + if seg1.ident == "Box" { + TypeIdentResult::boxed(inner) + } else if seg1.ident == "Vec" { + TypeIdentResult::vec(inner) + } else { + panic!(); + } + } else { + match args.first() { + Some(GenericArgument::Type(it)) => { + let inner = it.get_ident(); + if seg1.ident == "Option" { + TypeIdentResult::option(inner) + } else { + inner + } + } + Some(GenericArgument::Lifetime(_)) => { + TypeIdentResult::Ident(&seg1.ident) + } + _ => panic!("unsupported type!"), + } + } + } + PathArguments::Parenthesized(_) => { + panic!("Parenthesized path arguments aren't supported!") + } + } + } + Type::Reference(typ) => TypeIdentResult::reference(typ.elem.get_ident()), + _ => panic!("Unsupported type."), + } + } +} + +impl> StrExt for T { + fn to_plural(self) -> String { + let txt = self.as_ref(); + if txt.is_empty() { + return String::default(); + } + + let mut txt = txt.to_string(); + if txt.ends_with("child") { + txt.push_str("ren"); + } else { + match txt.chars().last() { + Some('s') => { + txt.push_str("es"); + } + Some('y') => { + txt.pop(); + txt.push_str("ies"); + } + _ => txt.push('s'), + } + } + txt + } +} + +impl TokenStreamExt for TokenStream { + fn replace_ident(self, needle: &str, replace: &Ident) -> TokenStream { + self.into_iter() + .map(|it| match it { + TokenTree::Ident(ident) if ident == needle => replace.to_token_stream(), + TokenTree::Group(group) => { + Group::new(group.delimiter(), group.stream().replace_ident(needle, replace)) + .to_token_stream() + } + _ => it.to_token_stream(), + }) + .collect() + } +}