diff --git a/Cargo.lock b/Cargo.lock index 3243a33ccb0a5..ac10ae4bed122 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1337,6 +1337,7 @@ dependencies = [ name = "oxc_ast_codegen" version = "0.0.0" dependencies = [ + "convert_case", "itertools 0.13.0", "lazy_static", "prettyplease", diff --git a/crates/oxc_ast/src/ast/js.rs b/crates/oxc_ast/src/ast/js.rs index 075b0578bdd00..1d4f95314dbce 100644 --- a/crates/oxc_ast/src/ast/js.rs +++ b/crates/oxc_ast/src/ast/js.rs @@ -78,6 +78,7 @@ pub enum Expression<'a> { ChainExpression(Box<'a, ChainExpression<'a>>) = 16, ClassExpression(Box<'a, Class<'a>>) = 17, ConditionalExpression(Box<'a, ConditionalExpression<'a>>) = 18, + #[visit_args(flags = None)] FunctionExpression(Box<'a, Function<'a>>) = 19, ImportExpression(Box<'a, ImportExpression<'a>>) = 20, LogicalExpression(Box<'a, LogicalExpression<'a>>) = 21, @@ -250,6 +251,7 @@ pub enum ArrayExpressionElement<'a> { /// Elision(Elision) = 65, // `Expression` variants added here by `inherit_variants!` macro + // TODO: support for attributes syntax here so we can use `#[visit_as(ExpressionArrayElement)]` @inherit Expression } } @@ -967,6 +969,7 @@ pub struct BlockStatement<'a> { #[cfg_attr(feature = "serialize", serde(untagged))] pub enum Declaration<'a> { VariableDeclaration(Box<'a, VariableDeclaration<'a>>) = 32, + #[visit_args(flags = None)] FunctionDeclaration(Box<'a, Function<'a>>) = 33, ClassDeclaration(Box<'a, Class<'a>>) = 34, UsingDeclaration(Box<'a, UsingDeclaration<'a>>) = 35, @@ -1291,6 +1294,7 @@ pub struct TryStatement<'a> { pub span: Span, pub block: Box<'a, BlockStatement<'a>>, pub handler: Option>>, + #[visit_as(FinallyClause)] pub finalizer: Option>>, } @@ -1432,7 +1436,7 @@ pub struct BindingRestElement<'a> { #[visited_node] #[scope( // TODO: `ScopeFlags::Function` is not correct if this is a `MethodDefinition` - flags(ScopeFlags::Function), + flags(flags.unwrap_or(ScopeFlags::empty()) | ScopeFlags::Function), strict_if(self.body.as_ref().is_some_and(|body| 
body.has_use_strict_directive())), )] #[derive(Debug)] @@ -1586,6 +1590,7 @@ pub struct Class<'a> { pub decorators: Vec<'a, Decorator<'a>>, #[scope(enter_before)] pub id: Option>, + #[visit_as(ClassHeritage)] pub super_class: Option>, pub body: Box<'a, ClassBody<'a>>, pub type_parameters: Option>>, @@ -1635,6 +1640,12 @@ pub struct MethodDefinition<'a> { pub span: Span, pub decorators: Vec<'a, Decorator<'a>>, pub key: PropertyKey<'a>, + #[visit_args(flags = Some(match self.kind { + MethodDefinitionKind::Get => ScopeFlags::GetAccessor, + MethodDefinitionKind::Set => ScopeFlags::SetAccessor, + MethodDefinitionKind::Constructor => ScopeFlags::Constructor, + MethodDefinitionKind::Method => ScopeFlags::empty(), + }))] pub value: Box<'a, Function<'a>>, // FunctionExpression pub kind: MethodDefinitionKind, pub computed: bool, @@ -1946,9 +1957,11 @@ inherit_variants! { #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(untagged))] pub enum ExportDefaultDeclarationKind<'a> { + #[visit_args(flags = None)] FunctionDeclaration(Box<'a, Function<'a>>) = 64, ClassDeclaration(Box<'a, Class<'a>>) = 65, + #[visit(ignore)] TSInterfaceDeclaration(Box<'a, TSInterfaceDeclaration<'a>>) = 66, // `Expression` variants added here by `inherit_variants!` macro diff --git a/crates/oxc_ast/src/ast/ts.rs b/crates/oxc_ast/src/ast/ts.rs index af9683a9981d8..d174cd5638050 100644 --- a/crates/oxc_ast/src/ast/ts.rs +++ b/crates/oxc_ast/src/ast/ts.rs @@ -802,7 +802,7 @@ pub enum TSTypePredicateName<'a> { #[visited_node] #[scope( flags(ScopeFlags::TsModuleBlock), - strict_if(self.body.as_ref().is_some_and(|body| body.is_strict())), + strict_if(self.body.as_ref().is_some_and(TSModuleDeclarationBody::is_strict)), )] #[derive(Debug)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] diff --git a/crates/oxc_ast_macros/src/lib.rs b/crates/oxc_ast_macros/src/lib.rs index 37154f9e53877..5f4f54d1be27f 100644 --- a/crates/oxc_ast_macros/src/lib.rs +++ 
b/crates/oxc_ast_macros/src/lib.rs @@ -22,7 +22,7 @@ pub fn visited_node(_args: TokenStream, input: TokenStream) -> TokenStream { /// Dummy derive macro for a non-existent trait `VisitedNode`. /// /// Does not generate any code, only purpose is to allow using `#[scope]` attr in the type def. -#[proc_macro_derive(VisitedNode, attributes(scope))] +#[proc_macro_derive(VisitedNode, attributes(scope, visit, visit_as, visit_args))] pub fn visited_node_derive(_item: TokenStream) -> TokenStream { TokenStream::new() } diff --git a/tasks/ast_codegen/Cargo.toml b/tasks/ast_codegen/Cargo.toml index 441d7db8d97ab..012c27454bd2c 100644 --- a/tasks/ast_codegen/Cargo.toml +++ b/tasks/ast_codegen/Cargo.toml @@ -30,3 +30,4 @@ serde = { workspace = true, features = ["derive"] } regex = { workspace = true } prettyplease = { workspace = true } lazy_static = { workspace = true } +convert_case = { workspace = true } diff --git a/tasks/ast_codegen/src/defs.rs b/tasks/ast_codegen/src/defs.rs index 1733d8e7503d5..ad7b5cb69706e 100644 --- a/tasks/ast_codegen/src/defs.rs +++ b/tasks/ast_codegen/src/defs.rs @@ -1,5 +1,5 @@ use super::{REnum, RStruct, RType}; -use crate::{schema::Inherit, TypeName}; +use crate::{schema::Inherit, util::TypeExt, TypeName}; use quote::ToTokens; use serde::Serialize; @@ -95,7 +95,7 @@ impl From<&Inherit> for EnumInheritDef { fn from(inherit: &Inherit) -> Self { match inherit { Inherit::Linked { super_, variants } => Self { - super_name: super_.into(), + super_name: super_.get_ident().as_ident().unwrap().to_string(), variants: variants.iter().map(Into::into).collect(), }, Inherit::Unlinked(_) => { diff --git a/tasks/ast_codegen/src/generators/ast_kind.rs b/tasks/ast_codegen/src/generators/ast_kind.rs index c0cea6da76478..e2696003ead73 100644 --- a/tasks/ast_codegen/src/generators/ast_kind.rs +++ b/tasks/ast_codegen/src/generators/ast_kind.rs @@ -8,7 +8,7 @@ use super::generated_header; pub struct AstKindGenerator; -const BLACK_LIST: [&str; 69] = [ +pub const 
BLACK_LIST: [&str; 69] = [ "Expression", "ObjectPropertyKind", "TemplateElement", @@ -102,6 +102,7 @@ impl Generator for AstKindGenerator { let have_kinds: Vec<(Ident, Type)> = ctx .ty_table .iter() + .filter(|it| it.borrow().visitable()) .filter_map(|maybe_kind| match &*maybe_kind.borrow() { kind @ (RType::Enum(_) | RType::Struct(_)) if kind.visitable() => { let ident = kind.ident().unwrap().clone(); diff --git a/tasks/ast_codegen/src/generators/mod.rs b/tasks/ast_codegen/src/generators/mod.rs index 01959fc2a0c87..d13a656ad7ce6 100644 --- a/tasks/ast_codegen/src/generators/mod.rs +++ b/tasks/ast_codegen/src/generators/mod.rs @@ -1,6 +1,7 @@ mod ast; mod ast_kind; mod impl_get_span; +mod visit; /// Inserts a newline in the `TokenStream`. #[allow(unused)] @@ -42,3 +43,4 @@ pub(crate) use insert; pub use ast::AstGenerator; pub use ast_kind::AstKindGenerator; pub use impl_get_span::ImplGetSpanGenerator; +pub use visit::VisitGenerator; diff --git a/tasks/ast_codegen/src/generators/visit.rs b/tasks/ast_codegen/src/generators/visit.rs new file mode 100644 index 0000000000000..a8e24e18f84d9 --- /dev/null +++ b/tasks/ast_codegen/src/generators/visit.rs @@ -0,0 +1,668 @@ +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + iter::Cloned, +}; + +use convert_case::{Case, Casing}; +use itertools::Itertools; +use proc_macro2::{TokenStream, TokenTree}; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parenthesized, + parse::{Parse, ParseStream}, + parse2, parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token::Paren, + Arm, Attribute, Expr, Field, GenericArgument, Ident, Meta, MetaNameValue, Path, PathArguments, + Token, Type, Variant, +}; + +use crate::{ + generators::{ast_kind::BLACK_LIST as KIND_BLACK_LIST, insert}, + schema::{Inherit, REnum, RStruct, RType}, + util::{StrExt, TokenStreamExt, TypeExt, TypeIdentResult}, + CodegenCtx, Generator, GeneratorOutput, Result, TypeRef, +}; + +use super::generated_header; + +pub struct VisitGenerator; + 
+/// Generator producing the `visit` output: the `Visit` trait plus the `walk` module.
+impl Generator for VisitGenerator {
+    fn name(&self) -> &'static str {
+        "VisitGenerator"
+    }
+
+    /// Emits a single output entry, keyed `"visit"`.
+    fn generate(&mut self, ctx: &CodegenCtx) -> GeneratorOutput {
+        let visit = (String::from("visit"), generate_visit(ctx));
+
+        GeneratorOutput::Many(HashMap::from_iter(vec![visit]))
+    }
+}
+
+/// Builds the token stream of the generated visitor file: header, file docs, the `Visit`
+/// trait (one `visit_*` method per reachable node) and the `walk` module (one `walk_*`
+/// function per `visit_*` method).
+fn generate_visit(ctx: &CodegenCtx) -> TokenStream {
+    let header = generated_header!();
+    // we evaluate it outside of quote to take advantage of expression evaluation
+    // otherwise the `\n\` wouldn't work!
+    let file_docs = insert! {"\
+        //! Visitor Pattern\n\
+        //!\n\
+        //! See:\n\
+        //! * [visitor pattern](https://rust-unofficial.github.io/patterns/patterns/behavioural/visitor.html)\n\
+        //! * [rustc visitor](https://github.com/rust-lang/rust/blob/master/compiler/rustc_ast/src/visit.rs)\n\
+    "};
+
+    let (visits, walks) = VisitBuilder::new(ctx).build();
+
+    quote! {
+        #header
+        #file_docs
+        insert!("#![allow(clippy::self_named_module_files, clippy::semicolon_if_nothing_returned, clippy::match_wildcard_for_single_variants)]");
+
+        endl!();
+
+        use oxc_allocator::Vec;
+        use oxc_syntax::scope::ScopeFlags;
+
+        endl!();
+
+        use crate::{ast::*, ast_kind::AstKind};
+
+        endl!();
+
+        use walk::*;
+
+        endl!();
+
+        /// Syntax tree traversal
+        pub trait Visit<'a>: Sized {
+            #[allow(unused_variables)]
+            fn enter_node(&mut self, kind: AstKind<'a>) {}
+            #[allow(unused_variables)]
+            fn leave_node(&mut self, kind: AstKind<'a>) {}
+
+            endl!();
+
+            #[allow(unused_variables)]
+            fn enter_scope(&mut self, flags: ScopeFlags) {}
+            fn leave_scope(&mut self) {}
+
+            endl!();
+
+            fn alloc<T>(&self, t: &T) -> &'a T {
+                insert!("// SAFETY:");
+                insert!("// This should be safe as long as `src` is an reference from the allocator.");
+                insert!("// But honestly, I'm not really sure if this is safe.");
+                #[allow(unsafe_code)]
+                unsafe {
+                    std::mem::transmute(t)
+                }
+            }
+
+            #(#visits)*
+        }
+
+        endl!();
+
+        pub mod walk {
+            use super::*;
+
+            #(#walks)*
+
+        }
+    }
+}
+
+/// Accumulates the generated `visit_*` methods and `walk_*` functions while the type graph
+/// is traversed starting from `Program`.
+struct VisitBuilder<'a> {
+    ctx: &'a CodegenCtx,
+    /// `visit_*` trait methods, in order of first request.
+    visits: Vec<TokenStream>,
+    /// `walk_*` functions, in the same order as `visits`.
+    walks: Vec<TokenStream>,
+    /// Generated visitor names per identifier: index `0` is the singular visitor,
+    /// index `1` the collection (`Vec<'a, T>`) visitor.
+    cache: HashMap<Ident, [Option<Cow<'a, Ident>>; 2]>,
+}
+
+impl<'a> VisitBuilder<'a> {
+    fn new(ctx: &'a CodegenCtx) -> Self {
+        Self { ctx, visits: Vec::new(), walks: Vec::new(), cache: HashMap::new() }
+    }
+
+    /// Generates every visitor reachable from the `Program` type and returns
+    /// `(visits, walks)`.
+    ///
+    /// # Panics
+    /// Panics if no visitable type named `Program` exists in the type table.
+    fn build(mut self) -> (/* visits */ Vec<TokenStream>, /* walks */ Vec<TokenStream>) {
+        let program = self
+            .ctx
+            .ty_table
+            .iter()
+            .filter(|it| it.borrow().visitable())
+            .find(|it| it.borrow().ident().is_some_and(|ident| ident == "Program"))
+            .map(TypeRef::clone)
+            .expect("Couldn't find the `Program` type!");
+
+        self.get_visitor(&program, false, None);
+        (self.visits, self.walks)
+    }
+
+    /// Returns the name of the `visit_*` method for `ty`, generating that method and its
+    /// matching `walk_*` function the first time it is requested.
+    ///
+    /// * `collection` - request the visitor for `Vec<'a, T>` instead of `T`.
+    /// * `visit_as` - identifier to use instead of the type's own name, both for naming the
+    ///   generated functions and as the cache key.
+    fn get_visitor(
+        &mut self,
+        ty: &TypeRef,
+        collection: bool,
+        visit_as: Option<&Ident>,
+    ) -> Cow<'a, Ident> {
+        let cache_ix = usize::from(collection);
+        let (ident, as_type) = {
+            let ty = ty.borrow();
+            debug_assert!(ty.visitable(), "{ty:?}");
+
+            let ident = ty.ident().unwrap();
+            let as_type = ty.as_type().unwrap();
+
+            let ident = visit_as.unwrap_or(ident);
+
+            (ident.clone(), if collection { parse_quote!(Vec<'a, #as_type>) } else { as_type })
+        };
+
+        // is it already generated?
+        if let Some(cached) = self.cache.get(&ident) {
+            if let Some(cached) = &cached[cache_ix] {
+                return Cow::clone(cached);
+            }
+        }
+
+        let ident_snake = {
+            let it = ident.to_string().to_case(Case::Snake);
+            let it = if collection {
+                // edge case for `Vec<FormalParameter>` to avoid conflicts with `FormalParameters`
+                // which both would generate the same name: `visit_formal_parameters`.
+                // and edge case for `Vec<TSImportAttribute>` to avoid conflicts with
+                // `TSImportAttributes` which both would generate the same name:
+                // `visit_ts_import_attributes`.
+                if matches!(it.as_str(), "formal_parameter" | "ts_import_attribute") {
+                    let mut it = it;
+                    it.push_str("_list");
+                    it
+                } else {
+                    it.to_plural()
+                }
+            } else {
+                it
+            };
+            format_ident!("{it}")
+        };
+
+        let as_param_type = quote!(&#as_type);
+        // `Function` visitors carry an extra `flags` argument.
+        let (extra_params, extra_args) = if ident == "Function" {
+            (quote!(, flags: Option<ScopeFlags>,), quote!(, flags))
+        } else {
+            (TokenStream::default(), TokenStream::default())
+        };
+
+        // Register the name *before* generating the walk body, so that recursive requests
+        // for the same visitor hit the cache instead of recursing forever.
+        //
+        // NOTE: the insertion must not live inside `debug_assert!`; its argument is not
+        // evaluated in release builds, which would leave the cache entry missing.
+        let visit_name: Cow<'a, Ident> = Cow::Owned(format_ident!("visit_{}", ident_snake));
+        {
+            let slots = self.cache.entry(ident.clone()).or_insert([None, None]);
+            let previous = slots[cache_ix].replace(Cow::clone(&visit_name));
+            assert!(previous.is_none(), "a visitor for `{ident}` was registered twice");
+        }
+
+        let walk_name = format_ident!("walk_{}", ident_snake);
+
+        self.visits.push(quote! {
+            endl!();
+            #[inline]
+            fn #visit_name (&mut self, it: #as_param_type #extra_params) {
+                #walk_name(self, it #extra_args);
+            }
+        });
+
+        // We push an empty walk first, because we evaluate - and generate - each walk as we go,
+        // This would let us to maintain the order of first visit.
+        let this_walker = self.walks.len();
+        self.walks.push(TokenStream::default());
+
+        let walk_body = if collection {
+            let singular_visit = self.get_visitor(ty, false, None);
+            quote! {
+                for el in it {
+                    visitor.#singular_visit(el);
+                }
+            }
+        } else {
+            match &*ty.borrow() {
+                RType::Enum(enum_) => self.generate_enum_walk(enum_, visit_as),
+                RType::Struct(struct_) => self.generate_struct_walk(struct_, visit_as),
+                _ => panic!("only enums and structs can be walked"),
+            }
+        };
+
+        // replace the placeholder walker with the actual one!
+        self.walks[this_walker] = quote! {
+            endl!();
+            pub fn #walk_name <'a, V: Visit<'a>>(visitor: &mut V, it: #as_param_type #extra_params) {
+                #walk_body
+            }
+        };
+
+        visit_name
+    }
+
+    /// Generates the body of a `walk_*` function for an enum: a `match` with one arm per
+    /// visitable own variant, followed by one arm per visitable inherited enum.
+    ///
+    /// Variants marked `#[visit(ignore)]` are skipped, in which case a wildcard arm is added.
+    /// Unless the (possibly `visit_as`-overridden) name is in `KIND_BLACK_LIST`, the match is
+    /// wrapped in `enter_node` / `leave_node` calls.
+    fn generate_enum_walk(&mut self, enum_: &REnum, visit_as: Option<&Ident>) -> TokenStream {
+        let ident = enum_.ident();
+        let mut non_exhaustive = false;
+        let variants_matches = enum_
+            .item
+            .variants
+            .iter()
+            // inherited variants are handled through `inherit_matches` below.
+            .filter(|it| !it.attrs.iter().any(|a| a.path().is_ident("inherit")))
+            .filter(|it| {
+                if it.attrs.iter().any(|a| {
+                    a.path().is_ident("visit")
+                        && a.meta
+                            .require_list()
+                            .unwrap()
+                            .parse_args::<Path>()
+                            .unwrap()
+                            .is_ident("ignore")
+                }) {
+                    // We are ignoring some variants so the match is no longer exhaustive.
+                    non_exhaustive = true;
+                    false
+                } else {
+                    true
+                }
+            })
+            .filter_map(|it| {
+                let typ = it
+                    .fields
+                    .iter()
+                    .exactly_one()
+                    .map(|f| &f.ty)
+                    .map_err(|_| "We only support visited enum nodes with exactly one field!")
+                    .unwrap();
+                let variant_name = &it.ident;
+                let typ = self.ctx.find(&typ.get_ident().inner_ident().to_string())?;
+                let borrowed = typ.borrow();
+                let visitable = borrowed.visitable();
+                if visitable {
+                    let visit = self.get_visitor(&typ, false, None);
+                    let (args_def, args) = Self::visit_args_of(&it.attrs);
+                    let body = quote!(visitor.#visit(it #(#args)*));
+                    let body = if args_def.is_empty() {
+                        body
+                    } else {
+                        // if we have args wrap the result in a block to prevent ident clashes.
+                        quote! {{
+                            #(#args_def)*
+                            #body
+                        }}
+                    };
+                    Some(quote!(#ident::#variant_name(it) => #body))
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
+
+        let inherit_matches = enum_.meta.inherits.iter().filter_map(|it| {
+            let Inherit::Linked { super_, .. } = it else { panic!("Unresolved inheritance!") };
+            let type_name = super_.get_ident().as_ident().unwrap().to_string();
+            let typ = self.ctx.find(&type_name)?;
+            if typ.borrow().visitable() {
+                let snake_name = type_name.to_case(Case::Snake);
+                let match_macro = format_ident!("match_{snake_name}");
+                let match_macro = quote!(#match_macro!(#ident));
+                // HACK: edge case till we get attributes to work with inheritance.
+                let child_visit_as = if ident == "ArrayExpressionElement"
+                    && super_.get_ident().inner_ident() == "Expression"
+                {
+                    Some(format_ident!("ExpressionArrayElement"))
+                } else {
+                    None
+                };
+                let to_child = format_ident!("to_{snake_name}");
+                let visit = self.get_visitor(&typ, false, child_visit_as.as_ref());
+                Some(quote!(#match_macro => visitor.#visit(it.#to_child())))
+            } else {
+                None
+            }
+        });
+
+        let matches = variants_matches.into_iter().chain(inherit_matches).collect_vec();
+
+        let with_node_events = |tk| {
+            let ident = visit_as.unwrap_or(ident);
+            if KIND_BLACK_LIST.contains(&ident.to_string().as_str()) {
+                tk
+            } else {
+                quote! {
+                    let kind = AstKind::#ident(visitor.alloc(it));
+                    visitor.enter_node(kind);
+                    #tk
+                    visitor.leave_node(kind);
+                }
+            }
+        };
+        let non_exhaustive = if non_exhaustive { Some(quote!(,_ => {})) } else { None };
+        with_node_events(quote!(match it { #(#matches),* #non_exhaustive }))
+    }
+
+    /// Generates the body of a `walk_*` function for a struct: one visit per visitable field,
+    /// wrapped in node events (unless the name is in `KIND_BLACK_LIST`) and, when the struct
+    /// carries a `#[scope]` attribute, in `enter_scope` / `leave_scope` calls.
+    ///
+    /// A field carrying a `#[scope]` attribute moves the `enter_scope` call to just before
+    /// that field's visit; at most one field may do so.
+    fn generate_struct_walk(&mut self, struct_: &RStruct, visit_as: Option<&Ident>) -> TokenStream {
+        let ident = visit_as.unwrap_or_else(|| struct_.ident());
+        let scope_attr = struct_.item.attrs.iter().find(|it| it.path().is_ident("scope"));
+        let (scope_enter, scope_leave) = scope_attr
+            .map(parse_as_scope)
+            .transpose()
+            .unwrap()
+            .map_or_else(Default::default, |scope_args| {
+                // `self` in attribute expressions refers to the node, which is `it` in walks.
+                let cond = scope_args.r#if.map(|cond| {
+                    let cond = cond.to_token_stream().replace_ident("self", &format_ident!("it"));
+                    quote!(let scope_events_cond = #cond;)
+                });
+                let maybe_conditional = |tk: TokenStream| {
+                    if cond.is_some() {
+                        quote! {
+                            if scope_events_cond {
+                                #tk
+                            }
+                        }
+                    } else {
+                        tk
+                    }
+                };
+                let flags = scope_args
+                    .flags
+                    .map_or_else(|| quote!(ScopeFlags::empty()), |it| it.to_token_stream());
+                let args = if let Some(strict_if) = scope_args.strict_if {
+                    let strict_if =
+                        strict_if.to_token_stream().replace_ident("self", &format_ident!("it"));
+                    quote! {{
+                        let mut flags = #flags;
+                        if #strict_if {
+                            flags |= ScopeFlags::StrictMode;
+                        }
+                        flags
+                    }}
+                } else {
+                    flags
+                };
+                let mut enter = cond.as_ref().into_token_stream();
+                enter.extend(maybe_conditional(quote!(visitor.enter_scope(#args);)));
+                let leave = maybe_conditional(quote!(visitor.leave_scope();));
+                (Some(enter), Some(leave))
+            });
+        let mut entered_scope = false;
+        let fields_visits: Vec<TokenStream> = struct_
+            .item
+            .fields
+            .iter()
+            .filter_map(|it| {
+                let (typ, typ_wrapper) = self.analyze_type(&it.ty)?;
+                let visit_as: Option<Ident> =
+                    it.attrs.iter().find(|it| it.path().is_ident("visit_as")).map(|it| {
+                        match &it.meta {
+                            Meta::List(meta) => {
+                                parse2(meta.tokens.clone()).expect("wrong `visit_as` input!")
+                            }
+                            _ => panic!("wrong use of `visit_as`!"),
+                        }
+                    });
+                // TODO: make sure it is `#[scope(enter_before)]`
+                let have_enter_scope = it.attrs.iter().any(|it| it.path().is_ident("scope"));
+                let (args_def, args) = Self::visit_args_of(&it.attrs);
+                let visit = self.get_visitor(
+                    &typ,
+                    matches!(
+                        typ_wrapper,
+                        TypeWrapper::Vec | TypeWrapper::VecBox | TypeWrapper::OptVec
+                    ),
+                    visit_as.as_ref(),
+                );
+                let name = it.ident.as_ref().expect("expected named fields!");
+                let mut result = match typ_wrapper {
+                    TypeWrapper::Opt | TypeWrapper::OptBox | TypeWrapper::OptVec => quote! {
+                        if let Some(ref #name) = it.#name {
+                            visitor.#visit(#name #(#args)*);
+                        }
+                    },
+                    TypeWrapper::VecOpt => quote! {
+                        for #name in (&it.#name).into_iter().flatten() {
+                            visitor.#visit(#name #(#args)*);
+                        }
+                    },
+                    _ => quote! {
+                        visitor.#visit(&it.#name #(#args)*);
+                    },
+                };
+                if have_enter_scope {
+                    assert!(!entered_scope);
+                    result = quote! {
+                        #scope_enter
+                        #result
+                    };
+                    entered_scope = true;
+                }
+
+                if args_def.is_empty() {
+                    Some(result)
+                } else {
+                    // if we have args wrap the result in a block to prevent ident clashes.
+                    Some(quote! {{
+                        #(#args_def)*
+                        #result
+                    }})
+                }
+            })
+            .collect();
+
+        let body = if KIND_BLACK_LIST.contains(&ident.to_string().as_str()) {
+            let unused =
+                if fields_visits.is_empty() { Some(quote!(let _ = (visitor, it);)) } else { None };
+            quote! {
+                insert!("// NOTE: AstKind doesn't exists!");
+                #(#fields_visits)*
+                #unused
+            }
+        } else {
+            quote! {
+                let kind = AstKind::#ident(visitor.alloc(it));
+                visitor.enter_node(kind);
+                #(#fields_visits)*
+                visitor.leave_node(kind);
+            }
+        };
+
+        match (scope_enter, scope_leave, entered_scope) {
+            // the scope was already entered in front of one of the fields.
+            (_, Some(leave), true) => quote! {
+                #body
+                #leave
+            },
+            (Some(enter), Some(leave), false) => quote! {
+                #enter
+                #body
+                #leave
+            },
+            _ => body,
+        }
+    }
+
+    /// Resolves a field type to the visitable type it wraps and the kind of wrapper around it.
+    ///
+    /// Returns `None` for references, for types unknown to the context and for types which
+    /// aren't visitable.
+    fn analyze_type(&self, ty: &Type) -> Option<(TypeRef, TypeWrapper)> {
+        fn analyze<'a>(res: &'a TypeIdentResult) -> Option<(&'a Ident, TypeWrapper)> {
+            let mut wrapper = TypeWrapper::None;
+            let ident = match res {
+                TypeIdentResult::Ident(inner) => inner,
+                TypeIdentResult::Box(inner) => {
+                    wrapper = TypeWrapper::Box;
+                    let (inner, inner_kind) = analyze(inner)?;
+                    assert!(inner_kind == TypeWrapper::None);
+                    inner
+                }
+                TypeIdentResult::Vec(inner) => {
+                    wrapper = TypeWrapper::Vec;
+                    let (inner, inner_kind) = analyze(inner)?;
+                    if inner_kind == TypeWrapper::Opt {
+                        wrapper = TypeWrapper::VecOpt;
+                    } else if inner_kind != TypeWrapper::None {
+                        panic!("unsupported wrapper nested inside `Vec`");
+                    }
+                    inner
+                }
+                TypeIdentResult::Option(inner) => {
+                    wrapper = TypeWrapper::Opt;
+                    let (inner, inner_kind) = analyze(inner)?;
+                    if inner_kind == TypeWrapper::Vec {
+                        wrapper = TypeWrapper::OptVec;
+                    } else if inner_kind == TypeWrapper::Box {
+                        wrapper = TypeWrapper::OptBox;
+                    } else if inner_kind != TypeWrapper::None {
+                        panic!("unsupported wrapper nested inside `Option`");
+                    }
+                    inner
+                }
+                TypeIdentResult::Reference(_) => return None,
+            };
+            Some((ident, wrapper))
+        }
+        let type_ident = ty.get_ident();
+        let (type_ident, wrapper) = analyze(&type_ident)?;
+
+        let type_ref = self.ctx.find(&type_ident.to_string())?;
+        if type_ref.borrow().visitable() {
+            Some((type_ref, wrapper))
+        } else {
+            None
+        }
+    }
+
+    /// Reads the `#[visit_args(name = expr, ...)]` attribute from `attrs`, if there is one.
+    ///
+    /// Returns `(definitions, arguments)`, see [`Self::visit_args_fold`]; both are empty when
+    /// the attribute is absent.
+    ///
+    /// # Panics
+    /// Panics if the attribute is present but can't be parsed, instead of silently generating
+    /// a visit call without its extra arguments.
+    fn visit_args_of(attrs: &[Attribute]) -> (Vec<TokenStream>, Vec<TokenStream>) {
+        attrs
+            .iter()
+            .find(|it| it.path().is_ident("visit_args"))
+            .map(|it| {
+                it.parse_args_with(VisitArgs::parse)
+                    .expect("Invalid `visit_args` input!")
+                    .into_iter()
+                    .fold((Vec::new(), Vec::new()), Self::visit_args_fold)
+            })
+            .unwrap_or_default()
+    }
+
+    /// Fold step turning one `name = expr` argument into a `let name = expr;` definition
+    /// (first vector) and a `, name` call argument (second vector). `self` in the expression
+    /// is rewritten to `it`.
+    fn visit_args_fold(
+        mut accumulator: (Vec<TokenStream>, Vec<TokenStream>),
+        arg: VisitArg,
+    ) -> (Vec<TokenStream>, Vec<TokenStream>) {
+        let VisitArg { ident: id, value: val } = arg;
+        let val = val.to_token_stream().replace_ident("self", &format_ident!("it"));
+        accumulator.0.push(quote!(let #id = #val;));
+        accumulator.1.push(quote!(, #id));
+        accumulator
+    }
+}
+
+/// Shape of the wrapper(s) found around a field's node type.
+#[derive(PartialEq)]
+enum TypeWrapper {
+    None,
+    Box,
+    Vec,
+    Opt,
+    // NOTE: never produced by `VisitBuilder::analyze_type`, only matched against.
+    VecBox,
+    VecOpt,
+    OptBox,
+    OptVec,
+}
+
+/// Parsed content of a `#[visit_args(...)]` attribute: comma separated `name = expr` pairs.
+#[derive(Debug)]
+struct VisitArgs(Punctuated<VisitArg, Token![,]>);
+
+impl IntoIterator for VisitArgs {
+    type Item = VisitArg;
+    type IntoIter = syn::punctuated::IntoIter<Self::Item>;
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+/// A single `name = expr` pair of a `#[visit_args(...)]` attribute.
+#[derive(Debug)]
+struct VisitArg {
+    ident: Ident,
+    value: Expr,
+}
+
+/// Parsed content of a `#[scope(...)]` attribute; every key is optional.
+#[derive(Debug, Default)]
+struct ScopeArgs {
+    r#if: Option<Expr>,
+    flags: Option<Expr>,
+    strict_if: Option<Expr>,
+}
+
+impl Parse for VisitArgs {
+    fn parse(input: ParseStream) -> std::result::Result<Self, syn::Error> {
+        input.parse_terminated(VisitArg::parse, Token![,]).map(Self)
+    }
+}
+
+impl Parse for VisitArg {
+    /// Parses `name = expr`; the left-hand side has to be a plain identifier.
+    fn parse(input: ParseStream) -> std::result::Result<Self, syn::Error> {
+        let nv: MetaNameValue = input.parse()?;
+        let ident = nv
+            .path
+            .get_ident()
+            .cloned()
+            .ok_or_else(|| syn::Error::new(nv.span(), "Invalid `visit_args` input!"))?;
+        Ok(Self { ident, value: nv.value })
+    }
+}
+
+impl Parse for ScopeArgs {
+    /// Parses comma separated `key(expr)` entries. The keys `if`, `flags` and `strict_if` are
+    /// stored; entries with any other key are parsed and then ignored.
+    fn parse(input: ParseStream) -> std::result::Result<Self, syn::Error> {
+        fn parse(input: ParseStream) -> std::result::Result<(String, Expr), syn::Error> {
+            // `if` is a keyword so it can't be parsed as an `Ident`.
+            let ident = if let Ok(ident) = input.parse::<Ident>() {
+                ident.to_string()
+            } else if input.parse::<Token![if]>().is_ok() {
+                String::from("if")
+            } else {
+                return Err(syn::Error::new(input.span(), "Invalid `#[scope]` input."));
+            };
+            let content;
+            parenthesized!(content in input);
+            Ok((ident, content.parse()?))
+        }
+
+        let parsed = input.parse_terminated(parse, Token![,])?;
+        Ok(parsed.into_iter().fold(Self::default(), |mut acc, (ident, expr)| {
+            match ident.as_str() {
+                "if" => acc.r#if = Some(expr),
+                "flags" => acc.flags = Some(expr),
+                "strict_if" => acc.strict_if = Some(expr),
+                _ => {}
+            }
+            acc
+        }))
+    }
+}
+
+/// Parses a `#[scope]` / `#[scope(...)]` attribute; the bare form yields the default args.
+fn parse_as_scope(attr: &Attribute) -> std::result::Result<ScopeArgs, syn::Error> {
+    debug_assert!(attr.path().is_ident("scope"));
+    if matches!(attr.meta, Meta::Path(_)) {
+        // empty!
+        Ok(ScopeArgs::default())
+    } else {
+        attr.parse_args_with(ScopeArgs::parse)
+    }
+}
diff --git a/tasks/ast_codegen/src/linker.rs b/tasks/ast_codegen/src/linker.rs
index 41e46391f6950..6d78f411a96a1 100644
--- a/tasks/ast_codegen/src/linker.rs
+++ b/tasks/ast_codegen/src/linker.rs
@@ -1,5 +1,7 @@
 use std::collections::VecDeque;
 
+use syn::parse_quote;
+
 use super::{CodegenCtx, Cow, Inherit, Itertools, RType, Result};
 
 pub trait Linker<'a> {
@@ -75,19 +77,27 @@ pub fn linker(ty: &mut RType, ctx: &CodegenCtx) -> Result {
             .map(|it| match it {
                 Inherit::Unlinked(ref sup) => {
                     let linkee = ctx.find(&Cow::Owned(sup.to_string())).unwrap();
-                    let variants = match &*linkee.borrow() {
+                    let linkee = linkee.borrow();
+                    let inherit_value = format!(r#""{}""#, linkee.ident().unwrap());
+                    let variants = match &*linkee {
                         RType::Enum(enum_) => {
                             if enum_.meta.inherits.unresolved() {
                                 return Err(it);
                             }
-                            enum_.item.variants.clone()
+                            enum_.item.variants.clone().into_iter().map(|mut v| {
+                                v.attrs = vec![parse_quote!(#[inherit = #inherit_value])];
+                                v
+                            })
                         }
                         _ => {
                             panic!("invalid inheritance, you can only inherit from enums and in enums.")
                         }
                     };
                     ty.item.variants.extend(variants.clone());
-                    Ok(Inherit::Linked { super_: sup.clone(), variants })
+                    Ok(Inherit::Linked {
+                        super_: linkee.as_type().unwrap(),
+                        variants: variants.collect(),
+                    })
                 }
                 Inherit::Linked { .. } => Ok(it),
             })
diff --git a/tasks/ast_codegen/src/main.rs b/tasks/ast_codegen/src/main.rs
index 3d1ad02976edb..475e221ef2ad3 100644
--- a/tasks/ast_codegen/src/main.rs
+++ b/tasks/ast_codegen/src/main.rs
@@ -5,6 +5,7 @@ mod fmt;
 mod generators;
 mod linker;
 mod schema;
+mod util;
 
 use std::{
     borrow::Cow,
@@ -22,7 +23,7 @@ use proc_macro2::TokenStream;
 use syn::parse_file;
 
 use defs::TypeDef;
-use generators::{AstGenerator, AstKindGenerator};
+use generators::{AstGenerator, AstKindGenerator, VisitGenerator};
 use linker::{linker, Linker};
 use schema::{Inherit, Module, REnum, RStruct, RType, Schema};
 
@@ -191,6 +192,7 @@ fn main() -> std::result::Result<(), Box> {
         .with(AstGenerator)
         .with(AstKindGenerator)
         .with(ImplGetSpanGenerator)
+        .with(VisitGenerator)
         .generate()?;
 
     let output_dir = output_dir()?;
@@ -218,6 +220,17 @@ fn main() -> std::result::Result<(), Box> {
         file.write_all(span_content.as_bytes())?;
     }
 
+    {
+        // write `visit.rs` file
+        let output = outputs[VisitGenerator.name()].as_many();
+        let span_content = pprint(&output["visit"]);
+
+        let path = format!("{output_dir}/visit.rs");
+        let mut file = fs::File::create(path)?;
+
+        file.write_all(span_content.as_bytes())?;
+    }
+
     cargo_fmt(".")?;
 
     // let schema = serde_json::to_string_pretty(&schema).map_err(|e| e.to_string())?;
diff --git a/tasks/ast_codegen/src/schema.rs b/tasks/ast_codegen/src/schema.rs
index c478004600216..a2c3019075050 100644
--- a/tasks/ast_codegen/src/schema.rs
+++ b/tasks/ast_codegen/src/schema.rs
@@ -27,7 +27,7 @@ pub struct Definitions {
 #[derive(Debug, Clone)]
 pub enum Inherit {
     Unlinked(String),
-    Linked { super_: String, variants: Punctuated<Variant, Token![,]> },
+    Linked { super_: Type, variants: Punctuated<Variant, Token![,]> },
 }
 
 impl From<Ident> for Inherit {
diff --git a/tasks/ast_codegen/src/util.rs b/tasks/ast_codegen/src/util.rs
new file mode 100644
index 0000000000000..7f94fb5908698
--- /dev/null
+++ b/tasks/ast_codegen/src/util.rs
@@ -0,0 +1,154 @@
+use itertools::Itertools;
+use proc_macro2::{Group, TokenStream, TokenTree};
+use quote::{quote, ToTokens};
+use syn::{GenericArgument, Ident, PathArguments, Type, TypePath};
+
+pub trait TokenStreamExt {
+    /// Returns the stream with every identifier equal to `needle` replaced by `replace`,
+    /// including identifiers nested inside groups.
+    fn replace_ident(self, needle: &str, replace: &Ident) -> TokenStream;
+}
+
+pub trait TypeExt {
+    /// Describes the identifier of a type together with the `Box` / `Vec` / `Option` /
+    /// reference layers wrapped around it.
+    fn get_ident(&self) -> TypeIdentResult;
+}
+
+pub trait StrExt: AsRef<str> {
+    /// Dead simple, just adds either `s` or `es` based on the last character.
+    /// It doesn't handle endings like `sh`, `x`, `z`, etc. A trailing `y` is always turned
+    /// into `ies`, even when it is preceded by a vowel: `toy` becomes the WRONG `toies`!
+    /// As an edge case, a word ending with `child` gets `ren` appended (`child` => `children`).
+    /// An empty input gives an empty output.
+    fn to_plural(self) -> String;
+}
+
+/// The identifier of a type, wrapped in one variant per container layer found around it.
+#[derive(Debug)]
+pub enum TypeIdentResult<'a> {
+    Ident(&'a Ident),
+    Vec(Box<TypeIdentResult<'a>>),
+    Box(Box<TypeIdentResult<'a>>),
+    Option(Box<TypeIdentResult<'a>>),
+    Reference(Box<TypeIdentResult<'a>>),
+}
+
+impl<'a> TypeIdentResult<'a> {
+    fn boxed(inner: Self) -> Self {
+        Self::Box(Box::new(inner))
+    }
+
+    fn vec(inner: Self) -> Self {
+        Self::Vec(Box::new(inner))
+    }
+
+    fn option(inner: Self) -> Self {
+        Self::Option(Box::new(inner))
+    }
+
+    fn reference(inner: Self) -> Self {
+        Self::Reference(Box::new(inner))
+    }
+
+    /// Returns the innermost identifier, looking through every wrapper layer.
+    pub fn inner_ident(&self) -> &'a Ident {
+        match self {
+            Self::Ident(it) => it,
+            Self::Vec(it) | Self::Box(it) | Self::Option(it) | Self::Reference(it) => {
+                it.inner_ident()
+            }
+        }
+    }
+
+    /// Returns the identifier only if there is no wrapper around it.
+    pub fn as_ident(&self) -> Option<&'a Ident> {
+        if let Self::Ident(it) = self {
+            Some(it)
+        } else {
+            None
+        }
+    }
+}
+
+impl TypeExt for Type {
+    /// Only the first path segment is inspected.
+    ///
+    /// * no generic arguments, or a single lifetime argument: the segment itself.
+    /// * one type argument: `Option<T>` becomes an `Option` layer, any other wrapper is
+    ///   transparent and yields the result of `T`.
+    /// * two arguments: has to be `Box<'a, T>` or `Vec<'a, T>`.
+    ///
+    /// # Panics
+    /// Panics on anything else: non path/reference types, parenthesized arguments, more than
+    /// two generic arguments or an unknown two-argument wrapper.
+    fn get_ident(&self) -> TypeIdentResult {
+        match self {
+            Type::Path(TypePath { path, .. }) => {
+                let seg1 = path.segments.first().expect("type path without any segment");
+                match &seg1.arguments {
+                    PathArguments::None => TypeIdentResult::Ident(&seg1.ident),
+                    PathArguments::AngleBracketed(it) => {
+                        let args = &it.args.iter().collect_vec();
+                        assert!(args.len() < 3, "Max path arguments here is 2, eg `Box<'a, Adt>`");
+                        if let Some(second) = args.get(1) {
+                            let GenericArgument::Type(second) = second else {
+                                panic!("expected a type as the second generic argument")
+                            };
+                            let inner = second.get_ident();
+                            if seg1.ident == "Box" {
+                                TypeIdentResult::boxed(inner)
+                            } else if seg1.ident == "Vec" {
+                                TypeIdentResult::vec(inner)
+                            } else {
+                                panic!("only `Box` and `Vec` can have 2 generic arguments");
+                            }
+                        } else {
+                            match args.first() {
+                                Some(GenericArgument::Type(it)) => {
+                                    let inner = it.get_ident();
+                                    if seg1.ident == "Option" {
+                                        TypeIdentResult::option(inner)
+                                    } else {
+                                        inner
+                                    }
+                                }
+                                Some(GenericArgument::Lifetime(_)) => {
+                                    TypeIdentResult::Ident(&seg1.ident)
+                                }
+                                _ => panic!("unsupported type!"),
+                            }
+                        }
+                    }
+                    PathArguments::Parenthesized(_) => {
+                        panic!("Parenthesized path arguments aren't supported!")
+                    }
+                }
+            }
+            Type::Reference(typ) => TypeIdentResult::reference(typ.elem.get_ident()),
+            _ => panic!("Unsupported type."),
+        }
+    }
+}
+
+impl<T: AsRef<str>> StrExt for T {
+    fn to_plural(self) -> String {
+        let txt = self.as_ref();
+        if txt.is_empty() {
+            return String::default();
+        }
+
+        let mut txt = txt.to_string();
+        if txt.ends_with("child") {
+            txt.push_str("ren");
+        } else {
+            match txt.chars().last() {
+                Some('s') => {
+                    txt.push_str("es");
+                }
+                Some('y') => {
+                    // swap the trailing `y` for `ies`.
+                    txt.pop();
+                    txt.push_str("ies");
+                }
+                _ => txt.push('s'),
+            }
+        }
+        txt
+    }
+}
+
+impl TokenStreamExt for TokenStream {
+    fn replace_ident(self, needle: &str, replace: &Ident) -> TokenStream {
+        self.into_iter()
+            .map(|it| match it {
+                TokenTree::Ident(ident) if ident == needle => replace.to_token_stream(),
+                // recurse into groups, keeping their delimiter.
+                TokenTree::Group(group) => {
+                    Group::new(group.delimiter(), group.stream().replace_ident(needle, replace))
+                        .to_token_stream()
+                }
+                _ => it.to_token_stream(),
+            })
+            .collect()
+    }
+}