Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/oxc_ast/src/ast_impl/js.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,16 @@ impl<'a> Function<'a> {
}
}

// FIXME: This is a workaround for we can't get current address by `TraverseCtx`,
// we will remove this once we support `TraverseCtx::current_address`.
// See: <https://github.com/oxc-project/oxc/pull/6881#discussion_r1816560516>
impl GetAddress for Function<'_> {
#[inline]
fn address(&self) -> Address {
Address::from_ptr(self)
}
}

impl<'a> FormalParameters<'a> {
/// Number of parameters bound in this parameter list.
pub fn parameters_count(&self) -> usize {
Expand Down
132 changes: 49 additions & 83 deletions crates/oxc_transformer/src/jsx/refresh.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::iter::once;

use base64::prelude::{Engine, BASE64_STANDARD};
use rustc_hash::FxHashMap;
use sha1::{Digest, Sha1};

use oxc_allocator::CloneIn;
use oxc_allocator::{CloneIn, GetAddress};
use oxc_ast::{ast::*, match_expression, AstBuilder, NONE};
use oxc_semantic::{Reference, ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags, SymbolId};
use oxc_semantic::{Reference, ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags};
use oxc_span::{Atom, GetSpan, SPAN};
use oxc_syntax::operator::AssignmentOperator;
use oxc_traverse::{Ancestor, BoundIdentifier, Traverse, TraverseCtx};
Expand Down Expand Up @@ -107,7 +105,6 @@ pub struct ReactRefresh<'a, 'ctx> {
/// Used to wrap call expression with signature.
/// (eg: hoc(() => {}) -> _s1(hoc(_s1(() => {}))))
last_signature: Option<(BindingIdentifier<'a>, oxc_allocator::Vec<'a, Argument<'a>>)>,
extra_statements: FxHashMap<SymbolId, oxc_allocator::Vec<'a, Statement<'a>>>,
// (function_scope_id, (hook_name, hook_key, custom_hook_callee)
hook_calls: FxHashMap<ScopeId, Vec<(Atom<'a>, Atom<'a>)>>,
non_builtin_hooks_callee: FxHashMap<ScopeId, Vec<Option<Expression<'a>>>>,
Expand All @@ -127,7 +124,6 @@ impl<'a, 'ctx> ReactRefresh<'a, 'ctx> {
registrations: Vec::default(),
ctx,
last_signature: None,
extra_statements: FxHashMap::default(),
hook_calls: FxHashMap::default(),
non_builtin_hooks_callee: FxHashMap::default(),
}
Expand Down Expand Up @@ -196,30 +192,18 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
stmts: &mut oxc_allocator::Vec<'a, Statement<'a>>,
ctx: &mut TraverseCtx<'a>,
) {
// TODO: check is there any function declaration

let mut new_stmts = ctx.ast.vec_with_capacity(stmts.len() + 1);

let declarations = self.signature_declarator_items.pop().unwrap();
if !declarations.is_empty() {
new_stmts.push(Statement::from(ctx.ast.declaration_variable(
SPAN,
VariableDeclarationKind::Var,
declarations,
false,
)));
stmts.insert(
0,
Statement::from(ctx.ast.declaration_variable(
SPAN,
VariableDeclarationKind::Var,
declarations,
false,
)),
);
}
new_stmts.extend(stmts.drain(..).flat_map(move |stmt| {
let symbol_ids = get_symbol_id_from_function_and_declarator(&stmt);
let extra_stmts = symbol_ids
.into_iter()
.filter_map(|symbol_id| self.extra_statements.remove(&symbol_id))
.flatten()
.collect::<Vec<_>>();
once(stmt).chain(extra_stmts)
}));

*stmts = new_stmts;
}

fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) {
Expand Down Expand Up @@ -268,7 +252,6 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {

let first_argument = Argument::from(id_binding.create_read_expression(ctx));
arguments.insert(0, first_argument);

let statement = ctx.ast.statement_expression(
SPAN,
ctx.ast.expression_call(
Expand All @@ -279,10 +262,27 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
false,
),
);
self.extra_statements
.entry(id_binding.symbol_id)
.or_insert(ctx.ast.vec())
.push(statement);

// Get the address of the statement containing this `VariableDeclarator`
#[allow(clippy::single_match_else)]
let address = match ctx.ancestor(2) {
// For `export const Foo = () => {}`
// which is a `VariableDeclaration` inside a `Statement::ExportNamedDeclaration`
Ancestor::ExportNamedDeclarationDeclaration(export_decl) => {
export_decl.address()
}
// Otherwise just a `const Foo = () => {}`
// which is a `Statement::VariableDeclaration`
_ => {
let var_decl = ctx.ancestor(1);
debug_assert!(matches!(
var_decl,
Ancestor::VariableDeclarationDeclarations(_)
));
var_decl.address()
}
};
self.ctx.statement_injector.insert_after(&address, statement);
return;
}
}
Expand Down Expand Up @@ -334,18 +334,23 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
arguments.insert(0, Argument::from(id_binding.create_read_expression(ctx)));

let binding = BoundIdentifier::from_binding_ident(&binding_identifier);
self.extra_statements.entry(id_binding.symbol_id).or_insert(ctx.ast.vec()).push(
ctx.ast.statement_expression(
SPAN,
ctx.ast.expression_call(
SPAN,
binding.create_read_expression(ctx),
NONE,
arguments,
false,
),
),
);
let callee = binding.create_read_expression(ctx);
let expr = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false);
let statement = ctx.ast.statement_expression(SPAN, expr);

// Get the address of the statement containing this `FunctionDeclaration`
let address = match ctx.parent() {
// For `export function Foo() {}`
// which is a `Statement::ExportNamedDeclaration`
Ancestor::ExportNamedDeclarationDeclaration(decl) => decl.address(),
// For `export default function() {}`
// which is a `Statement::ExportDefaultDeclaration`
Ancestor::ExportDefaultDeclarationDeclaration(decl) => decl.address(),
// Otherwise just a `function Foo() {}`
// which is a `Statement::FunctionDeclaration`
_ => func.address(),
};
self.ctx.statement_injector.insert_after(&address, statement);
}

fn enter_call_expression(
Expand Down Expand Up @@ -898,42 +903,3 @@ fn is_builtin_hook(hook_name: &str) -> bool {
"useOptimistic"
)
}

fn get_symbol_id_from_function_and_declarator(stmt: &Statement<'_>) -> Vec<SymbolId> {
let mut symbol_ids = vec![];
match stmt {
Statement::FunctionDeclaration(ref func) => {
if !func.is_typescript_syntax() {
symbol_ids.push(func.symbol_id().unwrap());
}
}
Statement::VariableDeclaration(ref decl) => {
symbol_ids.extend(decl.declarations.iter().filter_map(|decl| {
decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get())
}));
}
Statement::ExportNamedDeclaration(ref export_decl) => {
if let Some(Declaration::FunctionDeclaration(func)) = &export_decl.declaration {
if !func.is_typescript_syntax() {
symbol_ids.push(func.symbol_id().unwrap());
}
} else if let Some(Declaration::VariableDeclaration(decl)) = &export_decl.declaration {
symbol_ids.extend(decl.declarations.iter().filter_map(|decl| {
decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get())
}));
}
}
Statement::ExportDefaultDeclaration(ref export_decl) => {
if let ExportDefaultDeclarationKind::FunctionDeclaration(func) =
&export_decl.declaration
{
if let Some(id) = func.symbol_id() {
symbol_ids.push(id);
}
}
}
_ => {}
};

symbol_ids
}