diff --git a/tooling/nargo_cli/src/cli/expand_cmd/printer/hir.rs b/tooling/nargo_cli/src/cli/expand_cmd/printer/hir.rs index f1c51a634b0..e5cf1e96ea0 100644 --- a/tooling/nargo_cli/src/cli/expand_cmd/printer/hir.rs +++ b/tooling/nargo_cli/src/cli/expand_cmd/printer/hir.rs @@ -231,8 +231,7 @@ impl ItemPrinter<'_, '_, '_> { self.show_quoted(&tokens.0); } HirExpression::Unsafe(hir_block_expression) => { - self.push_str("// Safety: comment added by `nargo expand`\n"); - self.write_indent(); + // The safety comment was already outputted for the enclosing statement self.push_str("unsafe "); self.show_hir_block_expression(hir_block_expression); } @@ -486,31 +485,24 @@ impl ItemPrinter<'_, '_, '_> { } fn show_hir_statement(&mut self, statement: HirStatement) { + // A safety comment can be put before a statement and it applies to any `unsafe` + // expression inside it. Here we check if the statement has `unsafe` in it and + // put a safety comment right before it. When printing an `Unsafe` expression + // we'll never include a safety comment at that point. + let has_unsafe = self.statement_has_unsafe(&statement); + if has_unsafe { + self.push_str("// Safety: comment added by `nargo expand`\n"); + self.write_indent(); + } + match statement { HirStatement::Let(hir_let_statement) => { - // If this is `let ... = unsafe { }` then show the unsafe comment on top of `let` - if let HirExpression::Unsafe(_) = - self.interner.expression(&hir_let_statement.expression) - { - self.push_str("// Safety: comment added by `nargo expand`\n"); - self.write_indent(); - } - self.push_str("let "); self.show_hir_pattern(hir_let_statement.pattern); self.push_str(": "); self.show_type(&hir_let_statement.r#type); self.push_str(" = "); - - if let HirExpression::Unsafe(block_expression) = - self.interner.expression(&hir_let_statement.expression) - { - self.push_str("unsafe "); - self.show_hir_block_expression(block_expression); - } else { - self.show_hir_expression_id(hir_let_statement.expression); - } - + self.show_hir_expression_id(hir_let_statement.expression); self.push(';'); } HirStatement::Assign(hir_assign_statement) => { @@ -738,6 +730,146 @@ impl ItemPrinter<'_, '_, '_> { } } } + + fn statement_id_has_unsafe(&self, stmt_id: StmtId) -> bool { + let statement = self.interner.statement(&stmt_id); + self.statement_has_unsafe(&statement) + } + + fn statement_has_unsafe(&self, statement: &HirStatement) -> bool { + match statement { + HirStatement::Let(hir_let_statement) => { + self.expression_id_has_unsafe(hir_let_statement.expression) + } + HirStatement::Assign(hir_assign_statement) => { + self.expression_id_has_unsafe(hir_assign_statement.expression) + } + HirStatement::For(hir_for_statement) => { + // We don't check the block, as the block consists of statements and we + // can put the safety comment on top of the ones that have unsafe + self.expression_id_has_unsafe(hir_for_statement.start_range) + || self.expression_id_has_unsafe(hir_for_statement.end_range) + } + HirStatement::Loop(expr_id) => self.expression_id_has_unsafe(*expr_id), + HirStatement::While(expr_id, expr_id2) => { + self.expression_id_has_unsafe(*expr_id) || self.expression_id_has_unsafe(*expr_id2) + } + HirStatement::Break => false, + HirStatement::Continue => false, + HirStatement::Expression(expr_id) => self.expression_id_has_unsafe(*expr_id), + HirStatement::Semi(expr_id) => self.expression_id_has_unsafe(*expr_id), + HirStatement::Comptime(stmt_id) => self.statement_id_has_unsafe(*stmt_id), + HirStatement::Error => false, + } + } + + fn expression_id_has_unsafe(&self, expr_id: ExprId) -> bool { + let hir_expr = self.interner.expression(&expr_id); + self.expression_has_unsafe(hir_expr) + } + + fn expression_has_unsafe(&self, expr: HirExpression) -> bool { + match expr { + HirExpression::Ident(..) => false, + HirExpression::Literal(hir_literal) => match hir_literal { + HirLiteral::Array(hir_array_literal) | HirLiteral::Slice(hir_array_literal) => { + match hir_array_literal { + HirArrayLiteral::Standard(expr_ids) => { + expr_ids.iter().any(|expr_id| self.expression_id_has_unsafe(*expr_id)) + } + HirArrayLiteral::Repeated { repeated_element, length: _ } => { + self.expression_id_has_unsafe(repeated_element) + } + } + } + HirLiteral::FmtStr(_, expr_ids, _) => { + expr_ids.iter().any(|expr_id| self.expression_id_has_unsafe(*expr_id)) + } + HirLiteral::Bool(_) + | HirLiteral::Integer(..) + | HirLiteral::Str(_) + | HirLiteral::Unit => false, + }, + HirExpression::Block(_) => { + // A block consists of statements so if any of those have `unsafe`, those + // should have the safety comment, not this wrapping statement + false + } + HirExpression::Prefix(hir_prefix_expression) => { + self.expression_id_has_unsafe(hir_prefix_expression.rhs) + } + HirExpression::Infix(hir_infix_expression) => { + self.expression_id_has_unsafe(hir_infix_expression.lhs) + || self.expression_id_has_unsafe(hir_infix_expression.rhs) + } + HirExpression::Index(hir_index_expression) => { + self.expression_id_has_unsafe(hir_index_expression.collection) + || self.expression_id_has_unsafe(hir_index_expression.index) + } + HirExpression::Constructor(hir_constructor_expression) => hir_constructor_expression + .fields + .iter() + .any(|(_, expr_id)| self.expression_id_has_unsafe(*expr_id)), + HirExpression::EnumConstructor(hir_enum_constructor_expression) => { + hir_enum_constructor_expression + .arguments + .iter() + .any(|expr_id| self.expression_id_has_unsafe(*expr_id)) + } + HirExpression::MemberAccess(hir_member_access) => { + self.expression_id_has_unsafe(hir_member_access.lhs) + } + HirExpression::Call(hir_call_expression) => { + self.expression_id_has_unsafe(hir_call_expression.func) + || hir_call_expression + .arguments + .iter() + .any(|expr_id| self.expression_id_has_unsafe(*expr_id)) + } + HirExpression::Constrain(hir_constrain_expression) => { + self.expression_id_has_unsafe(hir_constrain_expression.0) + || hir_constrain_expression + .2 + .is_some_and(|expr_id| self.expression_id_has_unsafe(expr_id)) + } + HirExpression::Cast(hir_cast_expression) => { + self.expression_id_has_unsafe(hir_cast_expression.lhs) + } + HirExpression::If(hir_if_expression) => { + self.expression_id_has_unsafe(hir_if_expression.condition) + || self.expression_id_has_unsafe(hir_if_expression.consequence) + || hir_if_expression + .alternative + .is_some_and(|expr_id| self.expression_id_has_unsafe(expr_id)) + } + HirExpression::Match(hir_match) => self.hir_match_has_unsafe(&hir_match), + HirExpression::Tuple(expr_ids) => { + expr_ids.iter().any(|expr_id| self.expression_id_has_unsafe(*expr_id)) + } + HirExpression::Lambda(hir_lambda) => self.expression_id_has_unsafe(hir_lambda.body), + HirExpression::Quote(..) | HirExpression::Unquote(..) => false, + HirExpression::Unsafe(..) => true, + HirExpression::Error => false, + } + } + + fn hir_match_has_unsafe(&self, hir_match: &HirMatch) -> bool { + match hir_match { + HirMatch::Success(expr_id) => self.expression_id_has_unsafe(*expr_id), + HirMatch::Failure { .. } => false, + HirMatch::Guard { cond, body, otherwise } => { + self.expression_id_has_unsafe(*cond) + || self.expression_id_has_unsafe(*body) + || self.hir_match_has_unsafe(otherwise) + } + HirMatch::Switch(_, cases, hir_match) => { + cases.iter().any(|case| self.hir_match_has_unsafe(&case.body)) + || hir_match + .as_ref() + .is_some_and(|hir_match| self.hir_match_has_unsafe(hir_match)) + } + } + } } fn hir_expression_needs_parentheses(hir_expr: &HirExpression) -> bool { diff --git a/tooling/nargo_cli/tests/snapshots/expand/execution_success/uhashmap/execute__tests__expanded.snap b/tooling/nargo_cli/tests/snapshots/expand/execution_success/uhashmap/execute__tests__expanded.snap index 2d9fd8f916d..0cd4347f964 100644 --- a/tooling/nargo_cli/tests/snapshots/expand/execution_success/uhashmap/execute__tests__expanded.snap +++ b/tooling/nargo_cli/tests/snapshots/expand/execution_success/uhashmap/execute__tests__expanded.snap @@ -266,9 +266,8 @@ fn entries_examples(map: UHashMap {value}"); } }