diff --git a/.github/.generated_ast_watch_list.yml b/.github/.generated_ast_watch_list.yml index 818c9ee29d83c..9446a15aebd5a 100644 --- a/.github/.generated_ast_watch_list.yml +++ b/.github/.generated_ast_watch_list.yml @@ -27,5 +27,6 @@ src: - 'crates/oxc_ast/src/generated/ast_builder.rs' - 'crates/oxc_ast/src/generated/visit.rs' - 'crates/oxc_ast/src/generated/visit_mut.rs' + - 'crates/oxc_ast_macros/src/generated/ast.rs' - 'tasks/ast_tools/src/**' - '.github/.generated_ast_watch_list.yml' diff --git a/Cargo.lock b/Cargo.lock index 00ec0e5872991..cb3330ac070bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1464,11 +1464,6 @@ dependencies = [ [[package]] name = "oxc_ast_macros" version = "0.29.0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] [[package]] name = "oxc_ast_tools" @@ -1482,6 +1477,7 @@ dependencies = [ "proc-macro2", "quote", "regex", + "rustc-hash", "serde", "serde_json", "syn", diff --git a/crates/oxc_ast_macros/Cargo.toml b/crates/oxc_ast_macros/Cargo.toml index 216d8ac59f13e..7e6101e7ea4fa 100644 --- a/crates/oxc_ast_macros/Cargo.toml +++ b/crates/oxc_ast_macros/Cargo.toml @@ -19,8 +19,3 @@ workspace = true [lib] proc-macro = true doctest = false - -[dependencies] -proc-macro2 = { workspace = true } -quote = { workspace = true } -syn = { workspace = true, features = ["full", "parsing", "printing", "proc-macro"] } diff --git a/crates/oxc_ast_macros/src/ast.rs b/crates/oxc_ast_macros/src/ast.rs deleted file mode 100644 index 083233f874390..0000000000000 --- a/crates/oxc_ast_macros/src/ast.rs +++ /dev/null @@ -1,96 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; -use syn::{punctuated::Punctuated, token::Comma, Attribute, Fields, Ident, Item, ItemEnum}; - -pub fn ast(input: &Item) -> TokenStream { - let (head, tail) = match input { - Item::Enum(enum_) => (enum_repr(enum_), assert_generated_derives(&enum_.attrs)), - Item::Struct(struct_) => (quote!(#[repr(C)]), assert_generated_derives(&struct_.attrs)), - _ => unreachable!(), - }; - - quote! { - #[derive(::oxc_ast_macros::Ast)] - #head - #input - #tail - } -} - -/// If `enum_` has any non-unit variant, returns `#[repr(C, u8)]`, otherwise returns `#[repr(u8)]`. -fn enum_repr(enum_: &ItemEnum) -> TokenStream { - if enum_.variants.iter().any(|var| !matches!(var.fields, Fields::Unit)) { - quote!(#[repr(C, u8)]) - } else { - quote!(#[repr(u8)]) - } -} - -/// Generate assertions that traits used in `#[generate_derive]` are in scope. -/// -/// e.g. for `#[generate_derive(GetSpan)]`, it generates: -/// -/// ```rs -/// const _: () = { -/// { -/// trait AssertionTrait: ::oxc_span::GetSpan {} -/// impl AssertionTrait for T {} -/// } -/// }; -/// ``` -/// -/// If `GetSpan` is not in scope, or it is not the correct `oxc_span::GetSpan`, -/// this will raise a compilation error. -fn assert_generated_derives(attrs: &[Attribute]) -> TokenStream { - // NOTE: At this level we don't care if a trait is derived multiple times, It is the - // responsibility of the `ast_tools` to raise errors for those. - let assertion = attrs - .iter() - .filter(|attr| attr.path().is_ident("generate_derive")) - .flat_map(parse_attr) - .map(|derive| { - let (abs_derive, generics) = abs_trait(&derive); - quote! {{ - // NOTE: these are wrapped in a scope to avoid the need for unique identifiers. - trait AssertionTrait: #abs_derive #generics {} - impl AssertionTrait for T {} - }} - }); - quote!(const _: () = { #(#assertion)* };) -} - -#[inline] -fn parse_attr(attr: &Attribute) -> impl Iterator { - attr.parse_args_with(Punctuated::::parse_terminated) - .expect("`#[generate_derive]` only accepts traits as single segment paths. Found an invalid argument.") - .into_iter() -} - -// TODO: benchmark this to see if a lazy static cell containing `HashMap` would perform better. -#[inline] -fn abs_trait( - ident: &Ident, -) -> (/* absolute type path */ TokenStream, /* possible generics */ TokenStream) { - if ident == "CloneIn" { - (quote!(::oxc_allocator::CloneIn), quote!(<'static>)) - } else if ident == "GetSpan" { - (quote!(::oxc_span::GetSpan), TokenStream::default()) - } else if ident == "GetSpanMut" { - (quote!(::oxc_span::GetSpanMut), TokenStream::default()) - } else if ident == "ContentEq" { - (quote!(::oxc_span::cmp::ContentEq), TokenStream::default()) - } else if ident == "ContentHash" { - (quote!(::oxc_span::hash::ContentHash), TokenStream::default()) - } else { - invalid_derive(ident) - } -} - -#[cold] -fn invalid_derive(ident: &Ident) -> ! { - panic!( - "Invalid derive trait(generate_derive): {ident}.\n\ - Help: If you are trying to implement a new `generate_derive` trait, \ - make sure to add it to the list in `abs_trait` function." - ) -} diff --git a/crates/oxc_ast_macros/src/generated/ast.rs b/crates/oxc_ast_macros/src/generated/ast.rs new file mode 100644 index 0000000000000..45968f409680c --- /dev/null +++ b/crates/oxc_ast_macros/src/generated/ast.rs @@ -0,0 +1,581 @@ +// Auto-generated code, DO NOT EDIT DIRECTLY! +// To edit this generated file you have to edit `tasks/ast_tools/src/generators/ast_macro.rs` + +#![allow(clippy::useless_conversion)] + +#[allow(unused_imports)] +use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; + +pub fn gen(name: &str, input: TokenStream) -> TokenStream { + match name { + "AssignmentOperator" => gen_assignment_operator(input), + "BooleanLiteral" => gen_boolean_literal(input), + "JSXElement" => gen_jsx_element(input), + "NumberBase" => gen_number_base(input), + "Program" => gen_program(input), + "RegularExpression" => gen_regular_expression(input), + "TSThisParameter" => gen_ts_this_parameter(input), + "AccessorProperty" + | "Alternative" + | "ArrayAssignmentTarget" + | "ArrayExpression" + | "ArrayPattern" + | "ArrowFunctionExpression" + | "AssignmentExpression" + | "AssignmentPattern" + | "AssignmentTargetPropertyIdentifier" + | "AssignmentTargetPropertyProperty" + | "AssignmentTargetRest" + | "AssignmentTargetWithDefault" + | "AwaitExpression" + | "BigIntLiteral" + | "BinaryExpression" + | "BindingIdentifier" + | "BindingPattern" + | "BindingProperty" + | "BindingRestElement" + | "BlockStatement" + | "BoundaryAssertion" + | "BreakStatement" + | "CallExpression" + | "CapturingGroup" + | "CatchClause" + | "CatchParameter" + | "ChainExpression" + | "Character" + | "CharacterClass" + | "CharacterClassEscape" + | "CharacterClassRange" + | "Class" + | "ClassBody" + | "ClassString" + | "ClassStringDisjunction" + | "ComputedMemberExpression" + | "ConditionalExpression" + | "ContinueStatement" + | "DebuggerStatement" + | "Decorator" + | "Directive" + | "Disjunction" + | "DoWhileStatement" + | "Dot" + | "Elision" + | "EmptyObject" + | "EmptyStatement" + | "ExportAllDeclaration" + | "ExportDefaultDeclaration" + | "ExportNamedDeclaration" + | "ExportSpecifier" + | "ExpressionStatement" + | "Flags" + | "ForInStatement" + | "ForOfStatement" + | "ForStatement" + | "FormalParameter" + | "FormalParameters" + | "Function" + | "FunctionBody" + | "Hashbang" + | "IdentifierName" + | "IdentifierReference" + | "IfStatement" + | "IgnoreGroup" + | "ImportAttribute" + | "ImportDeclaration" + | "ImportDefaultSpecifier" + | "ImportExpression" + | "ImportNamespaceSpecifier" + | "ImportSpecifier" + | "IndexedReference" + | "JSDocNonNullableType" + | "JSDocNullableType" + | "JSDocUnknownType" + | "JSXAttribute" + | "JSXClosingElement" + | "JSXClosingFragment" + | "JSXEmptyExpression" + | "JSXExpressionContainer" + | "JSXFragment" + | "JSXIdentifier" + | "JSXMemberExpression" + | "JSXNamespacedName" + | "JSXOpeningElement" + | "JSXOpeningFragment" + | "JSXSpreadAttribute" + | "JSXSpreadChild" + | "JSXText" + | "LabelIdentifier" + | "LabeledStatement" + | "LogicalExpression" + | "LookAroundAssertion" + | "MetaProperty" + | "MethodDefinition" + | "ModifierFlags" + | "NamedReference" + | "NewExpression" + | "NullLiteral" + | "NumericLiteral" + | "ObjectAssignmentTarget" + | "ObjectExpression" + | "ObjectPattern" + | "ObjectProperty" + | "ParenthesizedExpression" + | "Pattern" + | "PrivateFieldExpression" + | "PrivateIdentifier" + | "PrivateInExpression" + | "PropertyDefinition" + | "Quantifier" + | "RegExp" + | "RegExpLiteral" + | "ReturnStatement" + | "SequenceExpression" + | "SourceType" + | "Span" + | "SpreadElement" + | "StaticBlock" + | "StaticMemberExpression" + | "StringLiteral" + | "Super" + | "SwitchCase" + | "SwitchStatement" + | "TSAnyKeyword" + | "TSArrayType" + | "TSAsExpression" + | "TSBigIntKeyword" + | "TSBooleanKeyword" + | "TSCallSignatureDeclaration" + | "TSClassImplements" + | "TSConditionalType" + | "TSConstructSignatureDeclaration" + | "TSConstructorType" + | "TSEnumDeclaration" + | "TSEnumMember" + | "TSExportAssignment" + | "TSExternalModuleReference" + | "TSFunctionType" + | "TSImportAttribute" + | "TSImportAttributes" + | "TSImportEqualsDeclaration" + | "TSImportType" + | "TSIndexSignature" + | "TSIndexSignatureName" + | "TSIndexedAccessType" + | "TSInferType" + | "TSInstantiationExpression" + | "TSInterfaceBody" + | "TSInterfaceDeclaration" + | "TSInterfaceHeritage" + | "TSIntersectionType" + | "TSIntrinsicKeyword" + | "TSLiteralType" + | "TSMappedType" + | "TSMethodSignature" + | "TSModuleBlock" + | "TSModuleDeclaration" + | "TSNamedTupleMember" + | "TSNamespaceExportDeclaration" + | "TSNeverKeyword" + | "TSNonNullExpression" + | "TSNullKeyword" + | "TSNumberKeyword" + | "TSObjectKeyword" + | "TSOptionalType" + | "TSParenthesizedType" + | "TSPropertySignature" + | "TSQualifiedName" + | "TSRestType" + | "TSSatisfiesExpression" + | "TSStringKeyword" + | "TSSymbolKeyword" + | "TSTemplateLiteralType" + | "TSThisType" + | "TSTupleType" + | "TSTypeAliasDeclaration" + | "TSTypeAnnotation" + | "TSTypeAssertion" + | "TSTypeLiteral" + | "TSTypeOperator" + | "TSTypeParameter" + | "TSTypeParameterDeclaration" + | "TSTypeParameterInstantiation" + | "TSTypePredicate" + | "TSTypeQuery" + | "TSTypeReference" + | "TSUndefinedKeyword" + | "TSUnionType" + | "TSUnknownKeyword" + | "TSVoidKeyword" + | "TaggedTemplateExpression" + | "TemplateElement" + | "TemplateElementValue" + | "TemplateLiteral" + | "ThisExpression" + | "ThrowStatement" + | "TryStatement" + | "UnaryExpression" + | "UnicodePropertyEscape" + | "UpdateExpression" + | "VariableDeclaration" + | "VariableDeclarator" + | "WhileStatement" + | "WithClause" + | "WithStatement" + | "YieldExpression" => repr_c(input), + "Argument" + | "ArrayExpressionElement" + | "AssignmentTargetMaybeDefault" + | "AssignmentTargetPattern" + | "AssignmentTargetProperty" + | "BindingPatternKind" + | "ChainElement" + | "CharacterClassContents" + | "ClassElement" + | "Declaration" + | "ExportDefaultDeclarationKind" + | "Expression" + | "ForStatementInit" + | "ForStatementLeft" + | "ImportAttributeKey" + | "ImportDeclarationSpecifier" + | "JSXAttributeItem" + | "JSXAttributeName" + | "JSXAttributeValue" + | "JSXChild" + | "JSXElementName" + | "JSXExpression" + | "JSXMemberExpressionObject" + | "MemberExpression" + | "ModuleDeclaration" + | "ModuleExportName" + | "ObjectPropertyKind" + | "PropertyKey" + | "RegExpPattern" + | "SimpleAssignmentTarget" + | "Statement" + | "TSEnumMemberName" + | "TSImportAttributeName" + | "TSLiteral" + | "TSModuleDeclarationBody" + | "TSModuleDeclarationName" + | "TSModuleReference" + | "TSSignature" + | "TSTupleElement" + | "TSType" + | "TSTypeName" + | "TSTypePredicateName" + | "TSTypeQueryExprName" + | "Term" => repr_c_u8(input), + "AccessorPropertyType" + | "AssignmentTarget" + | "BigintBase" + | "BinaryOperator" + | "BoundaryAssertionKind" + | "CharacterClassContentsKind" + | "CharacterClassEscapeKind" + | "CharacterKind" + | "ClassType" + | "FormalParameterKind" + | "FunctionType" + | "ImportOrExportKind" + | "Language" + | "LanguageVariant" + | "LogicalOperator" + | "LookAroundAssertionKind" + | "MethodDefinitionKind" + | "MethodDefinitionType" + | "ModuleKind" + | "PropertyDefinitionType" + | "PropertyKind" + | "TSAccessibility" + | "TSMappedTypeModifierOperator" + | "TSMethodSignatureKind" + | "TSModuleDeclarationKind" + | "TSTypeOperatorOperator" + | "UnaryOperator" + | "UpdateOperator" + | "VariableDeclarationKind" => repr_u8(input), + _ => unreachable!(), + } +} + +fn gen_boolean_literal(input: TokenStream) -> TokenStream { + let mut stream = repr_c(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream.extend(assert_get_span()); + stream.extend(assert_get_span_mut()); + stream +} + +fn gen_program(input: TokenStream) -> TokenStream { + let mut stream = repr_c(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream.extend(assert_get_span()); + stream.extend(assert_get_span_mut()); + stream +} + +fn gen_ts_this_parameter(input: TokenStream) -> TokenStream { + let mut stream = repr_c(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream.extend(assert_get_span()); + stream.extend(assert_get_span_mut()); + stream +} + +fn gen_jsx_element(input: TokenStream) -> TokenStream { + let mut stream = repr_c(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream.extend(assert_get_span()); + stream.extend(assert_get_span_mut()); + stream +} + +fn gen_number_base(input: TokenStream) -> TokenStream { + let mut stream = repr_u8(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream +} + +fn gen_assignment_operator(input: TokenStream) -> TokenStream { + let mut stream = repr_u8(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream +} + +fn gen_regular_expression(input: TokenStream) -> TokenStream { + let mut stream = repr_c(input); + stream.extend(assert_clone_in()); + stream.extend(assert_content_eq()); + stream.extend(assert_content_hash()); + stream +} + +fn derive_ast() -> TokenStream { + [ + TokenTree::Punct(Punct::new('#', Spacing::Alone)), + TokenTree::Group(Group::new( + Delimiter::Bracket, + [ + TokenTree::Ident(Ident::new("derive", Span::call_site())), + TokenTree::Group(Group::new( + Delimiter::Parenthesis, + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_ast_macros", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("Ast", Span::call_site())), + ] + .into_iter() + .collect(), + )), + ] + .into_iter() + .collect(), + )), + ] + .into_iter() + .collect() +} + +fn repr_c(input: TokenStream) -> TokenStream { + repr(TokenStream::from(TokenTree::Ident(Ident::new("C", Span::call_site()))), input) +} + +fn repr_u8(input: TokenStream) -> TokenStream { + repr(TokenStream::from(TokenTree::Ident(Ident::new("u8", Span::call_site()))), input) +} + +fn repr_c_u8(input: TokenStream) -> TokenStream { + repr( + [ + TokenTree::Ident(Ident::new("C", Span::call_site())), + TokenTree::Punct(Punct::new(',', Spacing::Alone)), + TokenTree::Ident(Ident::new("u8", Span::call_site())), + ] + .into_iter() + .collect(), + input, + ) +} + +fn repr(rep: TokenStream, input: TokenStream) -> TokenStream { + let mut stream = derive_ast(); + stream.extend(repr_raw(rep)); + stream.extend(input); + stream +} + +fn repr_raw(rep: TokenStream) -> TokenStream { + [ + TokenTree::Punct(Punct::new('#', Spacing::Alone)), + TokenTree::Group(Group::new( + Delimiter::Bracket, + [ + TokenTree::Ident(Ident::new("repr", Span::call_site())), + TokenTree::Group(Group::new(Delimiter::Parenthesis, rep.into_iter().collect())), + ] + .into_iter() + .collect(), + )), + ] + .into_iter() + .collect() +} + +fn assert_clone_in() -> TokenStream { + assert( + [ + TokenTree::Ident(Ident::new("CloneIn", Span::call_site())), + TokenTree::Punct(Punct::new('<', Spacing::Alone)), + TokenTree::Punct(Punct::new('\'', Spacing::Joint)), + TokenTree::Ident(Ident::new("static", Span::call_site())), + TokenTree::Punct(Punct::new('>', Spacing::Alone)), + ] + .into_iter() + .collect(), + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_allocator", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("CloneIn", Span::call_site())), + TokenTree::Punct(Punct::new('<', Spacing::Alone)), + TokenTree::Punct(Punct::new('\'', Spacing::Joint)), + TokenTree::Ident(Ident::new("static", Span::call_site())), + TokenTree::Punct(Punct::new('>', Spacing::Alone)), + ] + .into_iter() + .collect(), + ) +} + +fn assert_get_span() -> TokenStream { + assert( + TokenStream::from(TokenTree::Ident(Ident::new("GetSpan", Span::call_site()))), + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_span", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("GetSpan", Span::call_site())), + ] + .into_iter() + .collect(), + ) +} + +fn assert_get_span_mut() -> TokenStream { + assert( + TokenStream::from(TokenTree::Ident(Ident::new("GetSpanMut", Span::call_site()))), + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_span", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("GetSpanMut", Span::call_site())), + ] + .into_iter() + .collect(), + ) +} + +fn assert_content_eq() -> TokenStream { + assert( + TokenStream::from(TokenTree::Ident(Ident::new("ContentEq", Span::call_site()))), + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_span", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("cmp", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("ContentEq", Span::call_site())), + ] + .into_iter() + .collect(), + ) +} + +fn assert_content_hash() -> TokenStream { + assert( + TokenStream::from(TokenTree::Ident(Ident::new("ContentHash", Span::call_site()))), + [ + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("oxc_span", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("hash", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("ContentHash", Span::call_site())), + ] + .into_iter() + .collect(), + ) +} + +fn assert(name: TokenStream, path: TokenStream) -> TokenStream { + [ + TokenTree::Ident(Ident::new("const", Span::call_site())), + TokenTree::Ident(Ident::new("_", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Group(Group::new(Delimiter::Parenthesis, TokenStream::new())), + TokenTree::Punct(Punct::new('=', Spacing::Alone)), + TokenTree::Group(Group::new( + Delimiter::Brace, + [ + TokenTree::Ident(Ident::new("trait", Span::call_site())), + TokenTree::Ident(Ident::new("AssertionTrait", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + ] + .into_iter() + .chain(path.into_iter()) + .chain( + [ + TokenTree::Group(Group::new(Delimiter::Brace, TokenStream::new())), + TokenTree::Ident(Ident::new("impl", Span::call_site())), + TokenTree::Punct(Punct::new('<', Spacing::Alone)), + TokenTree::Ident(Ident::new("T", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + ] + .into_iter(), + ) + .chain(name.into_iter()) + .chain( + [ + TokenTree::Punct(Punct::new('>', Spacing::Alone)), + TokenTree::Ident(Ident::new("AssertionTrait", Span::call_site())), + TokenTree::Ident(Ident::new("for", Span::call_site())), + TokenTree::Ident(Ident::new("T", Span::call_site())), + TokenTree::Group(Group::new(Delimiter::Brace, TokenStream::new())), + ] + .into_iter(), + ) + .collect(), + )), + TokenTree::Punct(Punct::new(';', Spacing::Alone)), + ] + .into_iter() + .collect() +} diff --git a/crates/oxc_ast_macros/src/lib.rs b/crates/oxc_ast_macros/src/lib.rs index f657bff3281d2..d6ae54c66412a 100644 --- a/crates/oxc_ast_macros/src/lib.rs +++ b/crates/oxc_ast_macros/src/lib.rs @@ -1,7 +1,8 @@ -use proc_macro::TokenStream; -use syn::{parse_macro_input, Item}; +use proc_macro::{TokenStream, TokenTree}; -mod ast; +mod generated { + pub mod ast; +} /// This attribute serves two purposes. /// First, it is a marker for our `ast_tools` to detect AST types. @@ -70,9 +71,8 @@ mod ast; /// 2. `tsify` #[proc_macro_attribute] pub fn ast(_args: TokenStream, input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as Item); - let expanded = ast::ast(&input); - TokenStream::from(expanded) + let name = get_type_name(&input); + generated::ast::gen(&name, input) } /// Dummy derive macro for a non-existent trait `Ast`. @@ -85,3 +85,38 @@ pub fn ast(_args: TokenStream, input: TokenStream) -> TokenStream { pub fn ast_derive(_input: TokenStream) -> TokenStream { TokenStream::new() } + +/// Get type name as a string. +/// +/// # Panics +/// Panics if type is not `struct` or `enum`. +fn get_type_name(input: &TokenStream) -> String { + let mut it = input.clone().into_iter(); + let mut next = || match it.next() { + Some(tt) => tt, + None => unreachable!(), + }; + + loop { + let TokenTree::Ident(ident) = next() else { continue }; + let mut ident_str = ident.to_string(); + + // Skip over `pub`, `pub(crate)`, `pub(super)` + if ident_str.as_str() == "pub" { + let mut tt = next(); + if matches!(tt, TokenTree::Group(_)) { + tt = next(); + } + let TokenTree::Ident(ident) = tt else { unreachable!() }; + ident_str = ident.to_string(); + } + + assert!( + matches!(ident_str.as_str(), "struct" | "enum"), + "`#[ast] attr can only be applied to structs and enums" + ); + + let TokenTree::Ident(ident) = next() else { unreachable!() }; + return ident.to_string(); + } +} diff --git a/tasks/ast_tools/Cargo.toml b/tasks/ast_tools/Cargo.toml index fb30bbdcf2858..427ac0533ac7d 100644 --- a/tasks/ast_tools/Cargo.toml +++ b/tasks/ast_tools/Cargo.toml @@ -22,6 +22,7 @@ prettyplease = { workspace = true } proc-macro2 = { workspace = true } quote = { workspace = true } regex = { workspace = true } +rustc-hash = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } syn = { workspace = true, features = ["clone-impls", "derive", "extra-traits", "full", "parsing", "printing", "proc-macro"] } diff --git a/tasks/ast_tools/src/fmt.rs b/tasks/ast_tools/src/fmt.rs index df45187dba330..ad51cdb68bfe2 100644 --- a/tasks/ast_tools/src/fmt.rs +++ b/tasks/ast_tools/src/fmt.rs @@ -11,7 +11,10 @@ static WHITE_SPACES: &str = " \t"; /// Pretty print pub fn pretty_print(input: &TokenStream) -> String { - let result = prettyplease::unparse(&parse_file(input.to_string().as_str()).unwrap()); + let mut result = input.to_string(); + if let Ok(file) = parse_file(&result) { + result = prettyplease::unparse(&file); + } // `insert!` and `endl!` macros are not currently used // let result = ENDL_REGEX.replace_all(&result, EndlReplacer); // let result = INSERT_REGEX.replace_all(&result, InsertReplacer).to_string(); diff --git a/tasks/ast_tools/src/generators/ast_macro.rs b/tasks/ast_tools/src/generators/ast_macro.rs new file mode 100644 index 0000000000000..25c5508d55a0b --- /dev/null +++ b/tasks/ast_tools/src/generators/ast_macro.rs @@ -0,0 +1,315 @@ +//! Generator to generate the implementation of `#[ast]` macro. + +use std::borrow::Cow; + +use convert_case::{Case, Casing}; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use rustc_hash::{FxHashMap, FxHashSet}; + +use super::define_generator; +use crate::{ + codegen::{generated_header, LateCtx}, + output, + schema::TypeDef, + to_code::{code, to_code}, + Generator, GeneratorOutput, AST_MACROS_CRATE, +}; + +define_generator! { + pub struct AstMacroGenerator; +} + +enum Repr { + C, + U8, + CU8, +} + +struct TypeInfo<'t> { + def: &'t TypeDef, + repr: Repr, + first_in_module: Option<&'t str>, +} + +impl Generator for AstMacroGenerator { + fn generate(&mut self, ctx: &LateCtx) -> GeneratorOutput { + // Get info about types, and record what derive assertions are required for each module. + // Instead of outputting same derive assertions for lots of types, output them only once in + // each module, in the expansion of `#[ast]` macro for 1st type in that module. + let mut modules = FxHashMap::default(); + let type_infos = ctx + .schema() + .into_iter() + .map(|def| get_type_info(def, &mut modules)) + .collect::>(); + + // Generate derive assertion functions + let (trait_names_and_assert_fn_names, assert_fns) = get_assert_fns(); + + // Generate generator functions, and record all the types which use each generator function + let mut gen_fns = vec![]; + let mut fn_types: FxHashMap, Vec<&str>> = FxHashMap::default(); + for info in type_infos { + let gen_fn_name = + generate_type_fn(&info, &trait_names_and_assert_fn_names, &modules, &mut gen_fns); + let type_names = fn_types.entry(gen_fn_name).or_default(); + type_names.push(info.def.name()); + } + + // Generate match arms for `gen` function + let mut fn_types = fn_types.into_iter().collect::>(); + fn_types.sort_unstable_by(|(gen_fn_name1, _), (gen_fn_name2, _)| { + gen_fn_name1.cmp(gen_fn_name2) + }); + let match_arms = fn_types.into_iter().map(|(gen_fn_name, mut type_names)| { + let gen_fn_name = Ident::new(&gen_fn_name, Span::call_site()); + type_names.sort_unstable(); + quote! { #(#type_names)|* => #gen_fn_name(input) } + }); + + let header = generated_header!(); + + let derive_ast_fn = get_derive_ast_fn(); + let repr_fns = get_repr_fns(); + + GeneratorOutput( + output(AST_MACROS_CRATE, "ast.rs"), + quote! { + #header + + #![allow(clippy::useless_conversion)] + + ///@@line_break + #[allow(unused_imports)] + use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; + + ///@@line_break + pub fn gen(name: &str, input: TokenStream) -> TokenStream { + match name { + #(#match_arms,)* + _ => unreachable!(), + } + } + + ///@@line_break + #(#gen_fns)* + + ///@@line_break + #derive_ast_fn + + ///@@line_break + #repr_fns + + ///@@line_break + #assert_fns + }, + ) + } +} + +/// Get info about type. +/// +/// * Determine what `#[repr]` attribute it needs. +/// * Determine if is first type in its module +/// (in which case all derive assertions for the module will be made part of it's `#[ast] expansion`). +/// * Record what derive assertions it needs in `modules`. +fn get_type_info<'t>( + def: &'t TypeDef, + modules: &mut FxHashMap>, +) -> TypeInfo<'t> { + let (repr, derives, module_path) = match def { + TypeDef::Struct(struct_) => { + let repr = Repr::C; + let derives = &*struct_.generated_derives; + let module_path = &*struct_.module_path; + (repr, derives, module_path) + } + TypeDef::Enum(enum_) => { + let repr = if enum_.variants.iter().any(|variant| !variant.fields.is_empty()) { + Repr::CU8 + } else { + Repr::U8 + }; + let derives = &*enum_.generated_derives; + let module_path = &*enum_.module_path; + (repr, derives, module_path) + } + }; + + let (first_in_module, recorded_derives) = if let Some(derives) = modules.get_mut(module_path) { + (None, derives) + } else { + modules.insert(module_path.to_string(), FxHashSet::default()); + let recorded_derives = modules.get_mut(module_path).unwrap(); + (Some(module_path), recorded_derives) + }; + + for derive in derives { + recorded_derives.insert(derive.clone()); + } + + TypeInfo { def, repr, first_in_module } +} + +/// Generate function to generate macro expansion of `#[ast]` for a type. +/// +/// Store `TokenStream` for that function in `gen_fns`. +/// Or if `#[ast]` expansion for this type does not include derive assertions, +/// use one of the `repr_*` functions. +fn generate_type_fn( + info: &TypeInfo, + trait_names_and_assert_fn_names: &[(&str, Ident)], + modules: &FxHashMap>, + gen_fns: &mut Vec, +) -> Cow<'static, str> { + let repr = match info.repr { + Repr::C => "repr_c", + Repr::U8 => "repr_u8", + Repr::CU8 => "repr_c_u8", + }; + + let derives = match info.first_in_module { + Some(module_path) => { + let derives = modules.get(module_path).unwrap(); + if derives.is_empty() { + None + } else { + Some(derives) + } + } + None => None, + }; + + let Some(derives) = derives else { return Cow::Borrowed(repr) }; + + let mut derives = derives.iter().collect::>(); + derives.sort_unstable(); + + let name = info.def.name().as_str(); + let assertions = derives.iter().map(|trait_name| { + let (_, assert_fn_name) = trait_names_and_assert_fn_names + .iter() + .find(|(name, ..)| name == trait_name) + .unwrap_or_else(|| { + panic!( + "Invalid derive trait(generate_derive): {trait_name}.\n\ + Help: If you are trying to implement a new `generate_derive` trait, \ + make sure to add it to the list in `get_assert_fns` function." + ); + }); + quote! { stream.extend(#assert_fn_name()); } + }); + + let gen_fn_name = format!("gen_{}", name.to_case(Case::Snake)); + let gen_fn_ident = Ident::new(&gen_fn_name, Span::call_site()); + let repr = Ident::new(repr, Span::call_site()); + let gen_fn = quote! { + ///@@line_break + fn #gen_fn_ident(input: TokenStream) -> TokenStream { + let mut stream = #repr(input); + #(#assertions)* + stream + } + }; + gen_fns.push(gen_fn); + + Cow::Owned(gen_fn_name) +} + +/// Generate `derive_ast` function which constructs token stream for `#[derive(::oxc_ast_macros::Ast)]` +fn get_derive_ast_fn() -> TokenStream { + let derive_ast = code!( #[derive(::oxc_ast_macros::Ast)] ); + quote! { + fn derive_ast() -> TokenStream { + #derive_ast + } + } +} + +/// Generate `repr` functions which construct token streams for `#[repr(C)]`, `#[repr(u8)]`, `#[repr(C, u8)]` +fn get_repr_fns() -> TokenStream { + let c_ident = code!(C); + let u8_ident = code!(u8); + let c_u8_seq = code!(C, u8); + let repr = code!(#[repr(@{rep})]); + quote! { + ///@@line_break + fn repr_c(input: TokenStream) -> TokenStream { + repr(#c_ident, input) + } + + ///@@line_break + fn repr_u8(input: TokenStream) -> TokenStream { + repr(#u8_ident, input) + } + + ///@@line_break + fn repr_c_u8(input: TokenStream) -> TokenStream { + repr(#c_u8_seq, input) + } + + ///@@line_break + fn repr(rep: TokenStream, input: TokenStream) -> TokenStream { + let mut stream = derive_ast(); + stream.extend(repr_raw(rep)); + stream.extend(input); + stream + } + + ///@@line_break + fn repr_raw(rep: TokenStream) -> TokenStream { + #repr + } + } +} + +/// Generate derive assertion functions +fn get_assert_fns() -> (Vec<(&'static str, Ident)>, TokenStream) { + let (trait_names_and_assert_fn_names, assert_fns): (Vec<_>, Vec<_>) = [ + ("CloneIn", quote!(::oxc_allocator), true), + ("GetSpan", quote!(::oxc_span), false), + ("GetSpanMut", quote!(::oxc_span), false), + ("ContentEq", quote!(::oxc_span::cmp), false), + ("ContentHash", quote!(::oxc_span::hash), false), + ] + .into_iter() + .map(|(trait_name, trait_path, has_lifetime)| { + let trait_ident = Ident::new(trait_name, Span::call_site()); + let lifetime = if has_lifetime { quote!(<'static>) } else { TokenStream::new() }; + let trait_ident = quote! { #trait_ident #lifetime }; + + let fn_name = format!("assert_{}", trait_name.to_case(Case::Snake)); + let fn_name = Ident::new(&fn_name, Span::call_site()); + + let trait_name_code = to_code(trait_ident.clone()); + let trait_path_code = code!(#trait_path :: #trait_ident); + let fn_def = quote! { + ///@@line_break + fn #fn_name() -> TokenStream { + assert(#trait_name_code, #trait_path_code) + } + }; + + ((trait_name, fn_name), fn_def) + }) + .unzip(); + + let assertion = code! { + const _: () = { + // These are wrapped in a scope to avoid the need for unique identifiers + trait AssertionTrait: @{path} {} + impl AssertionTrait for T {} + }; + }; + let assert_fns = quote! { + #(#assert_fns)* + + ///@@line_break + fn assert(name: TokenStream, path: TokenStream) -> TokenStream { + #assertion + } + }; + + (trait_names_and_assert_fn_names, assert_fns) +} diff --git a/tasks/ast_tools/src/generators/mod.rs b/tasks/ast_tools/src/generators/mod.rs index e16f356d06f41..b684175d9139c 100644 --- a/tasks/ast_tools/src/generators/mod.rs +++ b/tasks/ast_tools/src/generators/mod.rs @@ -7,11 +7,13 @@ use crate::codegen::LateCtx; mod assert_layouts; mod ast_builder; mod ast_kind; +mod ast_macro; mod visit; pub use assert_layouts::AssertLayouts; pub use ast_builder::AstBuilderGenerator; pub use ast_kind::AstKindGenerator; +pub use ast_macro::AstMacroGenerator; pub use visit::{VisitGenerator, VisitMutGenerator}; /// Inserts a newline in the `TokenStream`. diff --git a/tasks/ast_tools/src/main.rs b/tasks/ast_tools/src/main.rs index 6e8dcf94ed346..4c49a7ca018ab 100644 --- a/tasks/ast_tools/src/main.rs +++ b/tasks/ast_tools/src/main.rs @@ -15,13 +15,14 @@ mod markers; mod passes; mod rust_ast; mod schema; +mod to_code; mod util; use derives::{DeriveCloneIn, DeriveContentEq, DeriveContentHash, DeriveGetSpan, DeriveGetSpanMut}; use fmt::cargo_fmt; use generators::{ - AssertLayouts, AstBuilderGenerator, AstKindGenerator, Generator, GeneratorOutput, - VisitGenerator, VisitMutGenerator, + AssertLayouts, AstBuilderGenerator, AstKindGenerator, AstMacroGenerator, Generator, + GeneratorOutput, VisitGenerator, VisitMutGenerator, }; use passes::{CalcLayout, Linker}; use util::{write_all_to, NormalizeError}; @@ -39,6 +40,7 @@ static SOURCE_PATHS: &[&str] = &[ ]; const AST_CRATE: &str = "crates/oxc_ast"; +const AST_MACROS_CRATE: &str = "crates/oxc_ast_macros"; type Result = std::result::Result; type TypeId = usize; @@ -79,6 +81,7 @@ fn main() -> std::result::Result<(), Box> { .generate(AstBuilderGenerator) .generate(VisitGenerator) .generate(VisitMutGenerator) + .generate(AstMacroGenerator) .run()?; if !cli_options.dry_run { diff --git a/tasks/ast_tools/src/to_code.rs b/tasks/ast_tools/src/to_code.rs new file mode 100644 index 0000000000000..5499a54964883 --- /dev/null +++ b/tasks/ast_tools/src/to_code.rs @@ -0,0 +1,189 @@ +use std::mem; + +use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, Spacing, TokenStream, TokenTree}; +use quote::quote; +use syn::Lit; + +/// Similar to `quote!` macro, except evaluates to a `TokenStream` of code to *create* the `TokenStream` +/// which `quote!` evaluates to. +/// +/// `code!( foo() )` evaluates to a `TokenStream` containing: +/// +/// ``` +/// TokenStream::from_iter([ +/// TokenTree::Ident(Ident::new("foo", Span::call_site())), +/// TokenTree::Group(Group::new(Delimiter::Parenthesis, TokenStream::new())) +/// ].into_iter()) +/// ``` +/// +/// `@{...}` can be used to inject code into output without conversion. +/// The value in `@{...}` must be a `TokenStream`. +/// e.g. `code!( foo( @{args} ) )` evaluates to a `TokenStream` containing: +/// +/// ``` +/// TokenStream::from_iter([ +/// TokenTree::Ident(Ident::new("foo", Span::call_site())), +/// TokenTree::Group(Group::new(Delimiter::Parenthesis, args)) +/// ].into_iter()) +/// ``` +macro_rules! code { + ($($tt:tt)*) => { + $crate::to_code::to_code(quote!($($tt)*)) + } +} +pub(crate) use code; + +/// Convert `TokenStream` into Rust code which creates that `TokenStream` +pub fn to_code(stream: TokenStream) -> TokenStream { + stream.to_code() +} + +/// Trait for converting `proc_macro`/`proc_macro2` types into Rust code which creates those types +trait ToCode { + fn to_code(self) -> TokenStream; +} + +impl ToCode for TokenStream { + fn to_code(self) -> TokenStream { + let mut stream = TokenStream::new(); + let mut trees = vec![]; + + let extend = |stream: &mut TokenStream, extend_stream| { + if stream.is_empty() { + *stream = extend_stream; + } else { + stream.extend(quote! { .chain(#extend_stream) }); + } + }; + + let mut it = self.into_iter(); + while let Some(tt) = it.next() { + // Leave contents of `@{...}` unchanged + if let TokenTree::Punct(punct) = &tt { + if punct.as_char() == '@' { + if let Some(TokenTree::Group(group)) = it.clone().next() { + if group.delimiter() == Delimiter::Brace { + it.next().unwrap(); + + if !trees.is_empty() { + let trees = mem::take(&mut trees); + extend(&mut stream, quote! { [ #(#trees),* ].into_iter() }); + } + + let extend_stream = group.stream(); + extend(&mut stream, quote! { #extend_stream.into_iter() }); + continue; + } + } + } + } + + trees.push(tt.to_code()); + } + + if trees.is_empty() { + if stream.is_empty() { + return quote! { TokenStream::new() }; + } + } else { + if stream.is_empty() && trees.len() == 1 { + let tree = trees.into_iter().next().unwrap(); + return quote! { TokenStream::from(#tree) }; + } + + extend(&mut stream, quote! { [ #(#trees),* ].into_iter() }); + } + + quote! { #stream.collect() } + } +} + +impl ToCode for TokenTree { + fn to_code(self) -> TokenStream { + match self { + TokenTree::Ident(ident) => { + let ident = ident.to_code(); + quote! { TokenTree::Ident(#ident) } + } + TokenTree::Punct(punct) => { + let punct = punct.to_code(); + quote! { TokenTree::Punct(#punct) } + } + TokenTree::Literal(literal) => { + let literal = literal.to_code(); + quote! { TokenTree::Literal(#literal) } + } + TokenTree::Group(group) => { + let group = group.to_code(); + quote! { TokenTree::Group(#group) } + } + } + } +} + +impl ToCode for Ident { + fn to_code(self) -> TokenStream { + let name = self.to_string(); + quote! { Ident::new(#name, Span::call_site()) } + } +} + +impl ToCode for Punct { + fn to_code(self) -> TokenStream { + let ch = self.as_char(); + let spacing = self.spacing().to_code(); + quote! { Punct::new(#ch, #spacing) } + } +} + +#[expect(clippy::todo)] +impl ToCode for Literal { + fn to_code(self) -> TokenStream { + let lit: Lit = syn::parse_str(&self.to_string()).unwrap(); + match lit { + Lit::Str(str) => { + let str = str.value(); + quote! { Literal::string(#str) } + } + Lit::Int(int) => match int.suffix() { + "u8" => { + let n = int.base10_parse::().unwrap(); + quote! { Literal::u8_suffixed(#n) } + } + // TODO: Other `Int` types + _ => todo!(), + }, + Lit::Bool(_) => unreachable!(), // `true` and `false` are `Ident`s + // TODO: Other `Lit` types + _ => todo!(), + } + } +} + +impl ToCode for Group { + fn to_code(self) -> TokenStream { + let delimiter = self.delimiter().to_code(); + let stream = self.stream().to_code(); + quote! { Group::new(#delimiter, #stream) } + } +} + +impl ToCode for Delimiter { + fn to_code(self) -> TokenStream { + match self { + Self::Parenthesis => quote!(Delimiter::Parenthesis), + Self::Brace => quote!(Delimiter::Brace), + Self::Bracket => quote!(Delimiter::Bracket), + Self::None => quote!(Delimiter::None), + } + } +} + +impl ToCode for Spacing { + fn to_code(self) -> TokenStream { + match self { + Self::Joint => quote!(Spacing::Joint), + Self::Alone => quote!(Spacing::Alone), + } + } +}