diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift index 82500b42703..59c49c39aed 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift @@ -63,6 +63,7 @@ enum MacroExpansionError: Error, CustomStringConvertible { case noFreestandingMacroRoles(Macro.Type) case moreThanOneBodyMacro case preambleWithoutBody + case circularExpansion(Macro.Type, any FreestandingMacroExpansionSyntax) var description: String { switch self { @@ -92,6 +93,9 @@ enum MacroExpansionError: Error, CustomStringConvertible { case .preambleWithoutBody: return "preamble macro cannot be applied to a function with no body" + + case .circularExpansion(let type, let syntax): + return "circular expansion detected: '\(syntax)' with macro implementation type '\(type)'" } } } diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift b/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift index 910e69553c6..bb3dc4011d8 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift @@ -667,6 +667,9 @@ private class MacroApplication: SyntaxRewriter { /// added to top-level 'CodeBlockItemList'. var extensions: [CodeBlockItemSyntax] = [] + /// Stores the types of the freestanding macros that are currently expanding. + var expandingFreestandingMacros: [any Macro.Type] = [] + init( macroSystem: MacroSystem, contextGenerator: @escaping (Syntax) -> Context, @@ -687,6 +690,11 @@ private class MacroApplication: SyntaxRewriter { return nil } + let macroCount = expandingFreestandingMacros.count + defer { + expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount) + } + // Expand 'MacroExpansionExpr'. // Note that 'MacroExpansionExpr'/'MacroExpansionExprDecl' at code item // position are handled by 'visit(_:CodeBlockItemListSyntax)'. @@ -792,6 +800,11 @@ private class MacroApplication: SyntaxRewriter { override func visit(_ node: CodeBlockItemListSyntax) -> CodeBlockItemListSyntax { var newItems: [CodeBlockItemSyntax] = [] func addResult(_ node: CodeBlockItemSyntax) { + let macroCount = expandingFreestandingMacros.count + defer { + expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount) + } + // Expand freestanding macro. switch expandCodeBlockItem(node: node) { case .success(let expanded): @@ -837,6 +850,11 @@ private class MacroApplication: SyntaxRewriter { var newItems: [MemberBlockItemSyntax] = [] func addResult(_ node: MemberBlockItemSyntax) { + let macroCount = expandingFreestandingMacros.count + defer { + expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount) + } + // Expand freestanding macro. switch expandMemberDecl(node: node) { case .success(let expanded): @@ -1226,7 +1244,13 @@ extension MacroApplication { else { return .notAMacro } + do { + guard expandingFreestandingMacros.allSatisfy({ $0 != macro }) else { + throw MacroExpansionError.circularExpansion(macro, node) + } + expandingFreestandingMacros.append(macro) + if let expanded = try expandMacro(macro, node) { return .success(expanded) } else { diff --git a/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift index 365a6b01383..03558ce5686 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift @@ -37,6 +37,46 @@ fileprivate struct StringifyMacro: ExpressionMacro { } } +private struct InfiniteRecursionMacro: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + if let i = node.arguments.first?.expression.as(IntegerLiteralExprSyntax.self)?.representedLiteralValue { + return "\(raw: i) + #infiniteRecursion(i: \(raw: i + 1))" + } else { + return "#nested1" + } + } +} + +private struct Nested1RecursionMacro: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + "(#nested2, #nested3, #infiniteRecursion(i: 1), #infiniteRecursion)" + } +} + +private struct Nested2RecursionMacro: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + "(#nested3, #nested3)" + } +} + +private struct Nested3RecursionMacro: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + "0" + } +} + final class ExpressionMacroTests: XCTestCase { private let indentationWidth: Trivia = .spaces(2) @@ -292,4 +332,31 @@ final class ExpressionMacroTests: XCTestCase { macros: ["test": DiagnoseFirstArgument.self] ) } + + func testDetectCircularExpansion() { + assertMacroExpansion( + "#nested1", + expandedSource: "((0, 0), 0, 1 + #infiniteRecursion(i: 2), #nested1)", + diagnostics: [ + DiagnosticSpec( + message: + "circular expansion detected: '#infiniteRecursion(i: 2)' with macro implementation type 'InfiniteRecursionMacro'", + line: 1, + column: 5 + ), + DiagnosticSpec( + message: + "circular expansion detected: '#nested1' with macro implementation type 'Nested1RecursionMacro'", + line: 1, + column: 1 + ), + ], + macros: [ + "nested1": Nested1RecursionMacro.self, + "nested2": Nested2RecursionMacro.self, + "nested3": Nested3RecursionMacro.self, + "infiniteRecursion": InfiniteRecursionMacro.self, + ] + ) + } }