Skip to content

Commit

Permalink
Merge pull request #2767 from AppAppWorks/detect-circular-expansion
Browse files Browse the repository at this point in the history
Detect circular macro expansion
  • Loading branch information
ahoppen authored Aug 9, 2024
2 parents 2256eaa + 6edc2e1 commit 8a9445f
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 12 deletions.
4 changes: 4 additions & 0 deletions Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum MacroExpansionError: Error, CustomStringConvertible {
case noFreestandingMacroRoles(Macro.Type)
case moreThanOneBodyMacro
case preambleWithoutBody
case recursiveExpansion(any Macro.Type)

var description: String {
switch self {
Expand Down Expand Up @@ -92,6 +93,9 @@ enum MacroExpansionError: Error, CustomStringConvertible {

case .preambleWithoutBody:
return "preamble macro cannot be applied to a function with no body"

case .recursiveExpansion(let type):
return "recursive expansion of macro '\(type)'"
}
}
}
Expand Down
82 changes: 70 additions & 12 deletions Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,12 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
/// added to top-level 'CodeBlockItemList'.
var extensions: [CodeBlockItemSyntax] = []

/// Stores the types of the freestanding macros that are currently expanding.
///
/// As macros are expanded by DFS, `expandingFreestandingMacros` always represent the expansion path starting from
/// the root macro node to the last macro node currently expanding.
var expandingFreestandingMacros: [any Macro.Type] = []

init(
macroSystem: MacroSystem,
contextGenerator: @escaping (Syntax) -> Context,
Expand All @@ -684,7 +690,7 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
}

override func visitAny(_ node: Syntax) -> Syntax? {
if skipVisitAnyHandling.contains(node) {
guard !skipVisitAnyHandling.contains(node) else {
return nil
}

Expand All @@ -693,8 +699,10 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
// position are handled by 'visit(_:CodeBlockItemListSyntax)'.
// Only expression expansions inside other syntax nodes is handled here.
switch expandExpr(node: node) {
case .success(let expanded):
return Syntax(visit(expanded))
case .success(let expansion):
return expansion.withExpandedNode { expandedNode in
Syntax(visit(expandedNode))
}
case .failure:
return Syntax(node)
case .notAMacro:
Expand Down Expand Up @@ -795,9 +803,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
func addResult(_ node: CodeBlockItemSyntax) {
// Expand freestanding macro.
switch expandCodeBlockItem(node: node) {
case .success(let expanded):
for item in expanded {
addResult(item)
case .success(let expansion):
expansion.withExpandedNode { expandedNode in
for item in expandedNode {
addResult(item)
}
}
return
case .failure:
Expand Down Expand Up @@ -840,9 +850,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
func addResult(_ node: MemberBlockItemSyntax) {
// Expand freestanding macro.
switch expandMemberDecl(node: node) {
case .success(let expanded):
for item in expanded {
addResult(item)
case .success(let expansion):
expansion.withExpandedNode { expandedNode in
for item in expandedNode {
addResult(item)
}
}
return
case .failure:
Expand Down Expand Up @@ -1218,9 +1230,36 @@ extension MacroApplication {
// MARK: Freestanding macro expansion

extension MacroApplication {
/// Encapsulates an expanded node, the type of the macro from which the node was expanded, and the macro application,
/// such that recursive macro expansion can be consistently detected.
struct MacroExpansion<ResultType> {
private let expandedNode: ResultType
private let macro: any Macro.Type
private unowned let macroApplication: MacroApplication

fileprivate init(expandedNode: ResultType, macro: any Macro.Type, macroApplication: MacroApplication) {
self.expandedNode = expandedNode
self.macro = macro
self.macroApplication = macroApplication
}

/// Invokes the given closure with the node resulting from a macro expansion.
///
/// This method inserts a pair of push and pop operations immediately around the invocation of `body` to maintain
/// an exact stack of expanding freestanding macros to detect recursive macro expansion. Callers should perform any
/// further macro expansion on `expanded` only within the scope of `body`.
func withExpandedNode<T>(_ body: (_ expandedNode: ResultType) throws -> T) rethrows -> T {
macroApplication.expandingFreestandingMacros.append(macro)
defer {
macroApplication.expandingFreestandingMacros.removeLast()
}
return try body(expandedNode)
}
}

enum MacroExpansionResult<ResultType> {
/// Expansion of the macro succeeded.
case success(ResultType)
case success(expansion: MacroExpansion<ResultType>)

/// Macro system found the macro to expand but running the expansion threw
/// an error and thus no expansion result exists.
Expand All @@ -1230,18 +1269,37 @@ extension MacroApplication {
case notAMacro
}

/// Expands the given freestanding macro node into a syntax node by invoking the given closure.
///
/// Any error thrown by `expandMacro` and circular expansion error will be added to diagnostics.
///
/// - Parameters:
/// - node: The freestanding macro node to be expanded.
/// - expandMacro: The closure that expands the given macro type and macro node into a syntax node.
///
/// - Returns:
/// Returns `.notAMacro` if `node` is `nil` or `node.macroName` isn't registered with any macro type.
/// Returns `.failure` if `expandMacro` throws an error or returns `nil`, or recursive expansion is detected.
/// Returns `.success` otherwise.
private func expandFreestandingMacro<ExpandedMacroType: SyntaxProtocol>(
_ node: (any FreestandingMacroExpansionSyntax)?,
expandMacro: (_ macro: Macro.Type, _ node: any FreestandingMacroExpansionSyntax) throws -> ExpandedMacroType?
expandMacro: (_ macro: any Macro.Type, _ node: any FreestandingMacroExpansionSyntax) throws -> ExpandedMacroType?
) -> MacroExpansionResult<ExpandedMacroType> {
guard let node,
let macro = macroSystem.lookup(node.macroName.text)?.type
else {
return .notAMacro
}

do {
guard !expandingFreestandingMacros.contains(where: { $0 == macro }) else {
// We may think of any ongoing macro expansion as a tree in which macro types being expanded are nodes.
// Any macro type being expanded more than once will create a cycle which the compiler as of now doesn't allow.
throw MacroExpansionError.recursiveExpansion(macro)
}

if let expanded = try expandMacro(macro, node) {
return .success(expanded)
return .success(expansion: MacroExpansion(expandedNode: expanded, macro: macro, macroApplication: self))
} else {
return .failure
}
Expand Down
67 changes: 67 additions & 0 deletions Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,46 @@ fileprivate struct StringifyMacro: ExpressionMacro {
}
}

private struct InfiniteRecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
if let i = node.arguments.first?.expression.as(IntegerLiteralExprSyntax.self)?.representedLiteralValue {
return "\(raw: i) + #infiniteRecursion(i: \(raw: i + 1))"
} else {
return "#nested1"
}
}
}

private struct Nested1RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested2, #nested3, #infiniteRecursion(i: 1), #infiniteRecursion)"
}
}

private struct Nested2RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested3, #nested3)"
}
}

private struct Nested3RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"0"
}
}

final class ExpressionMacroTests: XCTestCase {
private let indentationWidth: Trivia = .spaces(2)

Expand Down Expand Up @@ -292,4 +332,31 @@ final class ExpressionMacroTests: XCTestCase {
macros: ["test": DiagnoseFirstArgument.self]
)
}

func testDetectCircularExpansion() {
assertMacroExpansion(
"#nested1",
expandedSource: "((0, 0), 0, 1 + #infiniteRecursion(i: 2), #nested1)",
diagnostics: [
DiagnosticSpec(
message:
"recursive expansion of macro 'InfiniteRecursionMacro'",
line: 1,
column: 5
),
DiagnosticSpec(
message:
"recursive expansion of macro 'Nested1RecursionMacro'",
line: 1,
column: 1
),
],
macros: [
"nested1": Nested1RecursionMacro.self,
"nested2": Nested2RecursionMacro.self,
"nested3": Nested3RecursionMacro.self,
"infiniteRecursion": InfiniteRecursionMacro.self,
]
)
}
}

0 comments on commit 8a9445f

Please sign in to comment.