diff --git a/crates/oxc_formatter/src/ast_nodes/node.rs b/crates/oxc_formatter/src/ast_nodes/node.rs index 04f04395b3210..3142bc0f785b0 100644 --- a/crates/oxc_formatter/src/ast_nodes/node.rs +++ b/crates/oxc_formatter/src/ast_nodes/node.rs @@ -2,7 +2,7 @@ use core::fmt; use std::ops::Deref; use oxc_allocator::Allocator; -use oxc_ast::ast::Program; +use oxc_ast::ast::{ExpressionStatement, Program}; use oxc_span::{GetSpan, Span}; use super::AstNodes; @@ -110,3 +110,17 @@ impl<'a> AstNode<'a, Program<'a>> { AstNode { inner, parent, allocator, following_span: None } } } + +impl<'a> AstNode<'a, ExpressionStatement<'a>> { + /// Check if this ExpressionStatement is the body of an arrow function expression + /// + /// Example: + /// `() => expression;` + /// ^^^^^^^^^^ This ExpressionStatement is the body of an arrow function + /// + /// `() => { return expression; }` + /// ^^^^^^^^^^^^^^^^^^^^ This ExpressionStatement is NOT the body of an arrow function + pub fn is_arrow_function_body(&self) -> bool { + matches!(self.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression) + } +} diff --git a/crates/oxc_formatter/src/parentheses/expression.rs b/crates/oxc_formatter/src/parentheses/expression.rs index b3a3b6a368803..0f90a0edcf87a 100644 --- a/crates/oxc_formatter/src/parentheses/expression.rs +++ b/crates/oxc_formatter/src/parentheses/expression.rs @@ -131,10 +131,7 @@ impl NeedsParentheses<'_> for AstNode<'_, IdentifierReference<'_>> { matches!( parent, AstNodes::ExpressionStatement(stmt) if - !matches!( - stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) - if arrow.expression() - ) + !stmt.is_arrow_function_body() ) } } @@ -211,15 +208,7 @@ impl NeedsParentheses<'_> for AstNode<'_, StringLiteral<'_>> { if let AstNodes::ExpressionStatement(stmt) = self.parent { // `() => "foo"` - if let AstNodes::FunctionBody(arrow) = stmt.parent { - if let AstNodes::ArrowFunctionExpression(arrow) = arrow.parent { - !arrow.expression() - } else { - true - } - } else { - true - } + !stmt.is_arrow_function_body() } else { false } @@ -400,21 +389,11 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool { AstNodes::ExpressionStatement(stmt) => { let grand_parent = parent.parent(); - if matches!(grand_parent, AstNodes::FunctionBody(_)) { - let grand_grand_parent = grand_parent.parent(); - if matches!( - grand_grand_parent, - AstNodes::ArrowFunctionExpression(arrow) if arrow.expression() - ) { - // Skip ahead to grand_grand_parent by consuming ancestors - // until we reach it - for ancestor in ancestors.by_ref() { - if core::ptr::eq(ancestor, grand_grand_parent) { - break; - } - } - continue; - } + if stmt.is_arrow_function_body() { + // Skip `FunctionBody` and `ArrowFunctionExpression` + let skipped = ancestors.by_ref().nth(1); + debug_assert!(matches!(skipped, Some(AstNodes::ArrowFunctionExpression(_)))); + continue; } return false; @@ -534,15 +513,11 @@ impl NeedsParentheses<'_> for AstNode<'_, AssignmentExpression<'_>> { // - `{ x } = obj` -> `({ x } = obj)` = needed to prevent parsing as block statement // - `() => { x } = obj` -> `() => ({ x } = obj)` = needed in arrow function body // - `() => a = b` -> `() => (a = b)` = also parens needed - AstNodes::ExpressionStatement(parent) => { - let parent_parent = parent.parent; - if let AstNodes::FunctionBody(body) = parent_parent { - let parent_parent_parent = body.parent; - if matches!(parent_parent_parent, AstNodes::ArrowFunctionExpression(arrow) if arrow.expression()) - { - return true; - } + AstNodes::ExpressionStatement(stmt) => { + if stmt.is_arrow_function_body() { + return true; } + matches!(self.left, AssignmentTarget::ObjectAssignmentTarget(_)) && is_first_in_statement( self.span, @@ -588,12 +563,6 @@ impl NeedsParentheses<'_> for AstNode<'_, AssignmentExpression<'_>> { stmt.update.as_ref().is_some_and(|update| update.span() == self.span()); !(is_initializer || is_update) } - // Arrow functions, only need parens if assignment is the direct body: - // - `() => a = b` -> `() => (a = b)` = needed - // - `() => someFunc(a = b)` = no extra parens needed - AstNodes::ArrowFunctionExpression(arrow) => { - arrow.expression() && arrow.body.span() == self.span() - } // Default: need parentheses in most other contexts // - `new (a = b)` // - `(a = b).prop` @@ -617,8 +586,6 @@ impl NeedsParentheses<'_> for AstNode<'_, SequenceExpression<'_>> { | AstNodes::ForStatement(_) | AstNodes::ExpressionStatement(_) | AstNodes::SequenceExpression(_) - // Handled as part of the arrow function formatting - | AstNodes::ArrowFunctionExpression(_) ) } } @@ -988,8 +955,7 @@ fn is_first_in_statement( match ancestor { AstNodes::ExpressionStatement(stmt) => { - if matches!(stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression) - { + if stmt.is_arrow_function_body() { if mode == FirstInStatementMode::ExpressionStatementOrArrow { if is_not_first_iteration && matches!( diff --git a/crates/oxc_formatter/src/utils/jsx.rs b/crates/oxc_formatter/src/utils/jsx.rs index 5bb5b9cae54ea..97e9953dfbec7 100644 --- a/crates/oxc_formatter/src/utils/jsx.rs +++ b/crates/oxc_formatter/src/utils/jsx.rs @@ -90,13 +90,7 @@ pub fn get_wrap_state(parent: &AstNodes<'_>) -> WrapState { AstNodes::ExpressionStatement(stmt) => { // `() =>
` // ^^^^^^^^^^^ - if let AstNodes::FunctionBody(body) = stmt.parent - && matches!(body.parent, AstNodes::ArrowFunctionExpression(arrow) if arrow.expression) - { - WrapState::WrapOnBreak - } else { - WrapState::NoWrap - } + if stmt.is_arrow_function_body() { WrapState::WrapOnBreak } else { WrapState::NoWrap } } AstNodes::ComputedMemberExpression(member) => { if member.optional { diff --git a/crates/oxc_formatter/src/utils/member_chain/mod.rs b/crates/oxc_formatter/src/utils/member_chain/mod.rs index c366aecbc80b7..3604a1dd6202d 100644 --- a/crates/oxc_formatter/src/utils/member_chain/mod.rs +++ b/crates/oxc_formatter/src/utils/member_chain/mod.rs @@ -95,13 +95,7 @@ impl<'a, 'b> MemberChain<'a, 'b> { has_computed_property || is_factory(&identifier.name) || // If an identifier has a name that is shorter than the tab with, then we join it with the "head" - (matches!(parent, AstNodes::ExpressionStatement(stmt) if { - if let AstNodes::ArrowFunctionExpression(arrow) = stmt.grand_parent() { - !arrow.expression - } else { - true - } - }) + (matches!(parent, AstNodes::ExpressionStatement(stmt) if !stmt.is_arrow_function_body()) && has_short_name(&identifier.name, f.options().indent_width.value())) } else { matches!(node.as_ref(), Expression::ThisExpression(_)) diff --git a/crates/oxc_formatter/src/write/binary_like_expression.rs b/crates/oxc_formatter/src/write/binary_like_expression.rs index 909aee5013095..d58ea0097cb59 100644 --- a/crates/oxc_formatter/src/write/binary_like_expression.rs +++ b/crates/oxc_formatter/src/write/binary_like_expression.rs @@ -158,13 +158,7 @@ impl<'a, 'b> BinaryLikeExpression<'a, 'b> { AstNodes::JSXExpressionContainer(container) => { matches!(container.parent, AstNodes::JSXAttribute(_)) } - AstNodes::ExpressionStatement(statement) => { - if let AstNodes::FunctionBody(arrow) = statement.parent { - arrow.span == self.span() - } else { - false - } - } + AstNodes::ExpressionStatement(statement) => statement.is_arrow_function_body(), AstNodes::ConditionalExpression(conditional) => { !matches!( parent.parent(), diff --git a/crates/oxc_formatter/src/write/sequence_expression.rs b/crates/oxc_formatter/src/write/sequence_expression.rs index 174701f56588e..009ce1b528105 100644 --- a/crates/oxc_formatter/src/write/sequence_expression.rs +++ b/crates/oxc_formatter/src/write/sequence_expression.rs @@ -32,8 +32,8 @@ impl<'a> FormatWrite<'a> for AstNode<'a, SequenceExpression<'a>> { }); if matches!(self.parent, AstNodes::ForStatement(_)) - || (matches!(self.parent, AstNodes::ExpressionStatement(statement) if - !matches!(statement.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression))) + || (matches!(self.parent, AstNodes::ExpressionStatement(statement) + if !statement.is_arrow_function_body())) { write!(f, [indent(&rest)]) } else {