diff --git a/crates/oxc_transformer/src/common/mod.rs b/crates/oxc_transformer/src/common/mod.rs index 35ff1567063a4..c0d64beb470bf 100644 --- a/crates/oxc_transformer/src/common/mod.rs +++ b/crates/oxc_transformer/src/common/mod.rs @@ -65,6 +65,16 @@ impl<'a> Traverse<'a> for Common<'a, '_> { self.statement_injector.exit_statements(stmts, ctx); } + #[inline] + fn enter_statement(&mut self, stmt: &mut Statement<'a>, ctx: &mut TraverseCtx<'a>) { + self.statement_injector.enter_statement(stmt, ctx); + } + + #[inline] + fn exit_statement(&mut self, stmt: &mut Statement<'a>, ctx: &mut TraverseCtx<'a>) { + self.statement_injector.exit_statement(stmt, ctx); + } + fn enter_function(&mut self, func: &mut Function<'a>, ctx: &mut TraverseCtx<'a>) { self.arrow_function_converter.enter_function(func, ctx); } diff --git a/crates/oxc_transformer/src/common/statement_injector.rs b/crates/oxc_transformer/src/common/statement_injector.rs index 1374555ab21dd..2e7e66568c3f2 100644 --- a/crates/oxc_transformer/src/common/statement_injector.rs +++ b/crates/oxc_transformer/src/common/statement_injector.rs @@ -12,7 +12,7 @@ //! self.ctx.statement_injector.insert_many_after(address, statements); //! ``` -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use rustc_hash::FxHashMap; @@ -34,6 +34,14 @@ impl<'a, 'ctx> StatementInjector<'a, 'ctx> { } impl<'a> Traverse<'a> for StatementInjector<'a, '_> { + fn enter_statement(&mut self, stmt: &mut Statement<'a>, _ctx: &mut TraverseCtx<'a>) { + self.ctx.statement_injector.set_current_statement_address(stmt); + } + + fn exit_statement(&mut self, stmt: &mut Statement<'a>, _ctx: &mut TraverseCtx<'a>) { + self.ctx.statement_injector.set_current_statement_address(stmt); + } + fn exit_statements( &mut self, statements: &mut ArenaVec<'a, Statement<'a>>, @@ -57,6 +65,7 @@ struct AdjacentStatement<'a> { /// Store for statements to be added to the statements. pub struct StatementInjectorStore<'a> { + current_statement_address: Cell
, insertions: RefCell>>>, } @@ -64,7 +73,10 @@ pub struct StatementInjectorStore<'a> { impl StatementInjectorStore<'_> { /// Create new `StatementInjectorStore`. pub fn new() -> Self { - Self { insertions: RefCell::new(FxHashMap::default()) } + Self { + current_statement_address: Cell::new(Address::DUMMY), + insertions: RefCell::new(FxHashMap::default()), + } } } @@ -148,6 +160,44 @@ impl<'a> StatementInjectorStore<'a> { stmts.into_iter().map(|stmt| AdjacentStatement { stmt, direction: Direction::After }), ); } + + /// Add a statement to be inserted immediately before the current statement. + #[expect(unused)] + #[inline] + pub fn insert_before_current_statement(&self, stmt: Statement<'a>) { + debug_assert_ne!(self.current_statement_address.get(), Address::DUMMY); + self.insert_before_address(self.current_statement_address.get(), stmt); + } + + /// Add a statement to be inserted immediately after the current statement. + #[expect(unused)] + #[inline] + pub fn insert_after_current_statement(&self, stmt: Statement<'a>) { + debug_assert_ne!(self.current_statement_address.get(), Address::DUMMY); + self.insert_after_address(self.current_statement_address.get(), stmt); + } + + /// Add multiple statements to be inserted immediately before the current statement. + #[expect(unused)] + #[inline] + pub fn insert_many_before_current_statement(&self, stmts: S) + where + S: IntoIterator>, + { + debug_assert_ne!(self.current_statement_address.get(), Address::DUMMY); + self.insert_many_before_address(self.current_statement_address.get(), stmts); + } + + /// Add multiple statements to be inserted immediately after the current statement. + #[expect(unused)] + #[inline] + pub fn insert_many_after_current_statement(&self, stmts: S) + where + S: IntoIterator>, + { + debug_assert_ne!(self.current_statement_address.get(), Address::DUMMY); + self.insert_many_after_address(self.current_statement_address.get(), stmts); + } } // Internal methods @@ -167,6 +217,7 @@ impl<'a> StatementInjectorStore<'a> { .iter() .filter_map(|s| insertions.get(&s.address()).map(Vec::len)) .sum::(); + if new_statement_count == 0 { return; } @@ -195,4 +246,9 @@ impl<'a> StatementInjectorStore<'a> { *statements = new_statements; } + + #[inline] + fn set_current_statement_address(&self, stmt: &Statement<'a>) { + self.current_statement_address.set(stmt.address()); + } } diff --git a/crates/oxc_transformer/src/lib.rs b/crates/oxc_transformer/src/lib.rs index 048a15724195e..5547004c4fec2 100644 --- a/crates/oxc_transformer/src/lib.rs +++ b/crates/oxc_transformer/src/lib.rs @@ -530,6 +530,7 @@ impl<'a> Traverse<'a> for TransformerImpl<'a, '_> { } fn exit_statement(&mut self, stmt: &mut Statement<'a>, ctx: &mut TraverseCtx<'a>) { + self.common.exit_statement(stmt, ctx); if let Some(typescript) = self.x0_typescript.as_mut() { typescript.exit_statement(stmt, ctx); } @@ -549,6 +550,7 @@ impl<'a> Traverse<'a> for TransformerImpl<'a, '_> { } fn enter_statement(&mut self, stmt: &mut Statement<'a>, ctx: &mut TraverseCtx<'a>) { + self.common.enter_statement(stmt, ctx); if let Some(typescript) = self.x0_typescript.as_mut() { typescript.enter_statement(stmt, ctx); }