From dfafbd065064569d9fb703ee643af397499a075e Mon Sep 17 00:00:00 2001 From: Alex Skorulis Date: Tue, 10 Dec 2024 15:58:10 +1100 Subject: [PATCH] Create Resolvable macro --- Package.swift | 24 ++ Sources/KnitCodeGen/TypeNamer.swift | 7 +- Sources/KnitMacros/KnitMacros.swift | 6 + .../KnitMacros/MacroPropertyWrappers.swift | 23 ++ .../KnitMacrosPlugin.swift | 13 ++ .../ResolvableMacro.swift | 210 ++++++++++++++++++ Tests/KnitMacrosTests/ResolvableTests.swift | 164 ++++++++++++++ .../SwinjectResolutionTests.swift | 137 ++++++++++++ 8 files changed, 580 insertions(+), 4 deletions(-) create mode 100644 Sources/KnitMacros/KnitMacros.swift create mode 100644 Sources/KnitMacros/MacroPropertyWrappers.swift create mode 100644 Sources/KnitMacrosImplementations/KnitMacrosPlugin.swift create mode 100644 Sources/KnitMacrosImplementations/ResolvableMacro.swift create mode 100644 Tests/KnitMacrosTests/ResolvableTests.swift create mode 100644 Tests/KnitMacrosTests/SwinjectResolutionTests.swift diff --git a/Package.swift b/Package.swift index 7227aa6..0a03d23 100644 --- a/Package.swift +++ b/Package.swift @@ -1,6 +1,7 @@ // swift-tools-version: 5.10 // The swift-tools-version declares the minimum version of Swift required to build this package. +import CompilerPluginSupport import PackageDescription let package = Package( @@ -11,6 +12,7 @@ let package = Package( ], products: [ .library(name: "Knit", targets: ["Knit"]), + .library(name: "KnitMacros", targets: ["KnitMacros"] ), .plugin(name: "KnitBuildPlugin", targets: ["KnitBuildPlugin"]), .executable(name: "knit-cli", targets: ["knit-cli"]), ], @@ -88,6 +90,28 @@ let package = Package( "KnitCodeGen", ] ), + + // MARK: - Macro + .macro( + name: "KnitMacrosImplementations", + dependencies: [ + .product(name: "SwiftSyntaxMacros", package: "swift-syntax"), + .product(name: "SwiftCompilerPlugin", package: "swift-syntax"), + .target(name: "KnitCodeGen"), + ] + ), + .target(name: "KnitMacros", dependencies: ["KnitMacrosImplementations"]), + .testTarget( + name: "KnitMacrosTests", + dependencies: [ + "KnitMacrosImplementations", + .target(name: "KnitMacros"), + .target(name: "KnitCodeGen"), + .target(name: "Swinject"), + .product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"), + ] + ), + ], swiftLanguageVersions: [ // When this SPM package is imported by a Swift 6 toolchain it should still be used in the v5 language mode diff --git a/Sources/KnitCodeGen/TypeNamer.swift b/Sources/KnitCodeGen/TypeNamer.swift index 259af3f..c3d9057 100644 --- a/Sources/KnitCodeGen/TypeNamer.swift +++ b/Sources/KnitCodeGen/TypeNamer.swift @@ -4,7 +4,7 @@ import Foundation -enum TypeNamer { +public enum TypeNamer { /** Creates a name for a given Type signature. @@ -12,7 +12,7 @@ enum TypeNamer { See TypeNamerTests unit tests for examples. */ - static func computedIdentifierName(type: String) -> String { + public static func computedIdentifierName(type: String) -> String { let type = sanitizeType(type: type, keepGenerics: false) let lowercaseIndex = type.firstIndex { $0.isLowercase } if let lowercaseIndex { @@ -24,8 +24,7 @@ enum TypeNamer { } /// Simplifies the type name and removes invalid characters - - static func sanitizeType(type: String, keepGenerics: Bool) -> String { + public static func sanitizeType(type: String, keepGenerics: Bool) -> String { if isClosure(type: type) { // The naming doesn't work for function types, just return closure return "closure" diff --git a/Sources/KnitMacros/KnitMacros.swift b/Sources/KnitMacros/KnitMacros.swift new file mode 100644 index 0000000..a465ce4 --- /dev/null +++ b/Sources/KnitMacros/KnitMacros.swift @@ -0,0 +1,6 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +@attached(peer, names: named(make)) +public macro Resolvable() = #externalMacro(module: "KnitMacrosImplementations", type: "ResolvableMacro") diff --git a/Sources/KnitMacros/MacroPropertyWrappers.swift b/Sources/KnitMacros/MacroPropertyWrappers.swift new file mode 100644 index 0000000..350fc95 --- /dev/null +++ b/Sources/KnitMacros/MacroPropertyWrappers.swift @@ -0,0 +1,23 @@ +// Created by Alex Skorulis on 23/1/2025. + +import Foundation + +/// Defines that the parameter should be resolved using the provided name +/// The property wrapper is only used as a hint to the Resolvable macro and has no effect +@propertyWrapper public struct Named { + public var wrappedValue: Value + + public init(wrappedValue: Value, _ name: String) { + self.wrappedValue = wrappedValue + } +} + +/// Defines that the parameter should not be resolved from the DI graph but should be an argument +/// The property wrapper is only used as a hint to the Resolvable macro and has no effect +@propertyWrapper public struct Argument { + public var wrappedValue: Value + + public init(wrappedValue: Value) { + self.wrappedValue = wrappedValue + } +} diff --git a/Sources/KnitMacrosImplementations/KnitMacrosPlugin.swift b/Sources/KnitMacrosImplementations/KnitMacrosPlugin.swift new file mode 100644 index 0000000..4773f02 --- /dev/null +++ b/Sources/KnitMacrosImplementations/KnitMacrosPlugin.swift @@ -0,0 +1,13 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import SwiftCompilerPlugin +import SwiftSyntaxMacros + +@main +struct MacroFunPlugin: CompilerPlugin { + let providingMacros: [Macro.Type] = [ + ResolvableMacro.self + ] +} diff --git a/Sources/KnitMacrosImplementations/ResolvableMacro.swift b/Sources/KnitMacrosImplementations/ResolvableMacro.swift new file mode 100644 index 0000000..124aaae --- /dev/null +++ b/Sources/KnitMacrosImplementations/ResolvableMacro.swift @@ -0,0 +1,210 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import KnitCodeGen +import SwiftDiagnostics +import SwiftSyntax +import SwiftSyntaxMacros + +public struct ResolvableMacro: PeerMacro { + public static func expansion( + of node: AttributeSyntax, + providingPeersOf declaration: some DeclSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { + guard let resolverTypeArg = node.attributeName.as(IdentifierTypeSyntax.self)?.genericArgumentClause?.arguments.first else { + throw DiagnosticsError( + diagnostics: [.init(node: node, message: Error.missingResolverType)] + ) + } + let resolverType = resolverTypeArg.description + + let parameterClause: FunctionParameterClauseSyntax + let returnType: String + let makeCall: String + if let initDecl = declaration.as(InitializerDeclSyntax.self) { + parameterClause = initDecl.signature.parameterClause + returnType = "Self" + makeCall = ".init" + } else if let funcDecl = declaration.as(FunctionDeclSyntax.self) { + parameterClause = funcDecl.signature.parameterClause + guard let ret = funcDecl.signature.returnClause?.type.as(IdentifierTypeSyntax.self)?.name.text else { + throw DiagnosticsError( + diagnostics: [.init(node: funcDecl, message: Error.missingReturnType)] + ) + } + let isStatic = funcDecl.modifiers.contains { $0.name.text == "static" } + guard isStatic else { + throw DiagnosticsError( + diagnostics: [.init(node: funcDecl, message: Error.nonInitializerOrFunc)] + ) + } + returnType = ret + makeCall = funcDecl.name.text + } else { + throw DiagnosticsError( + diagnostics: [.init(node: node, message: Error.nonInitializerOrFunc)] + ) + } + + let params = try parameterClause.parameters.map { paramSyntax in + let type = try extractType(typeSyntax: paramSyntax.type) + let name = paramSyntax.firstName.text + let hint: ParamHint? = extractHint(paramSyntax: paramSyntax) + + return Param( + name: name, + type: type, + hint: hint, + defaultValue: extractDefault(paramSyntax: paramSyntax) + ) + } + + let paramsResolved = params.map { param in + return param.resolveCall + } + let paramsString = paramsResolved.joined(separator: ",\n") + var makeArguments = ["resolver: \(resolverType)"] + for param in params { + if param.isArgument { + makeArguments.append("\(param.name): \(param.type.name)") + } + } + + let makeArgumentsString = makeArguments.joined(separator: ", ") + + return [ + """ + static func make(\(raw: makeArgumentsString)) -> \(raw: returnType) { + return \(raw: makeCall)( + \(raw: paramsString) + ) + } + """ + ] + } + + private static func extractType(typeSyntax: TypeSyntax) throws -> TypeInformation { + if let type = typeSyntax.as(IdentifierTypeSyntax.self) { + return TypeInformation(name: type.name.text) + } else if let type = typeSyntax.as(AttributedTypeSyntax.self) { + let baseType = try extractType(typeSyntax: type.baseType) + return TypeInformation(name: baseType.name) + } else if let type = typeSyntax.as(FunctionTypeSyntax.self) { + return TypeInformation(name: "(\(type.description))") + } + throw DiagnosticsError( + diagnostics: [.init(node: typeSyntax, message: Error.invalidParamType(typeSyntax.description))] + ) + } + + private static func extractHint(paramSyntax: FunctionParameterSyntax) -> ParamHint? { + guard let type = paramSyntax.type.as(AttributedTypeSyntax.self) else { + return nil + } + for element in type.attributes { + guard case let AttributeListSyntax.Element.attribute(attribute) = element else { + continue + } + let name = attribute.attributeName.description.trimmingCharacters(in: .whitespaces) + if name == "Argument" { + return .argument + } else if name == "Named", + let arguments = attribute.arguments?.as(LabeledExprListSyntax.self), + let firstString = arguments.first?.expression.as(StringLiteralExprSyntax.self)?.textContent + { + return .named(firstString) + } + } + return nil + } + + private static func extractDefault(paramSyntax: FunctionParameterSyntax) -> String? { + guard let defaultValue = paramSyntax.defaultValue else { + return nil + } + return defaultValue.description.replacingOccurrences(of: "= ", with: "") + } + +} + +private extension ResolvableMacro { + struct Param { + let name: String + let type: TypeInformation + let hint: ParamHint? + let defaultValue: String? + + var isArgument: Bool { hint == .argument } + + var resolveCall: String { + let knitCallName = TypeNamer.computedIdentifierName(type: type.name) + if let defaultValue { + return "\(name): \(defaultValue)" + } else if let hint { + switch hint { + case let .named(serviceName): + return "\(name): resolver.\(knitCallName)(name: .\(serviceName))" + case .argument: + return "\(name): \(name)" + } + } else { + return "\(name): resolver.\(knitCallName)()" + } + } + } + + struct TypeInformation { + let name: String + + init(name: String) { + self.name = name + } + } + + enum ParamHint: Equatable { + case argument + case named(String) + } + + private struct HintContainer { + var hints: [String: ParamHint] + } + + enum Error: DiagnosticMessage { + case missingResolverType + case nonInitializerOrFunc + case missingReturnType + case invalidParamType(String) + + var message: String { + switch self { + case .missingResolverType: + return "@Resolvable requires a generic parameter" + case .nonInitializerOrFunc: + return "@Resolvable can only be used on init declarations or static functions" + case let .invalidParamType(string): + return "Unexpected parameter type: \(string)" + case .missingReturnType: + return "Could not identify function return type" + } + } + + var diagnosticID: MessageID { + MessageID(domain: "ResolvableMacro", id: message) + } + + var severity: DiagnosticSeverity { .error } + } +} + +// MARK: - Swift Syntax Extensions + +private extension StringLiteralExprSyntax { + + var textContent: String? { + segments.first?.as(StringSegmentSyntax.self)?.content + .description.trimmingCharacters(in: .init(charactersIn: "\"")) + } +} diff --git a/Tests/KnitMacrosTests/ResolvableTests.swift b/Tests/KnitMacrosTests/ResolvableTests.swift new file mode 100644 index 0000000..a4f4956 --- /dev/null +++ b/Tests/KnitMacrosTests/ResolvableTests.swift @@ -0,0 +1,164 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import KnitMacrosImplementations +import SwiftSyntaxMacros +import SwiftSyntaxMacrosTestSupport +import XCTest + +#if canImport(KnitMacrosImplementations) +import KnitMacrosImplementations + +let testMacros: [String: Macro.Type] = [ + "Resolvable": ResolvableMacro.self +] +#endif + +final class ResolvableTests: XCTestCase { + func test_macro_expansion() throws { + assertMacroExpansion( + """ + @Resolvable + init(arg1: String, arg2: Int) {} + """, + expandedSource: """ + + init(arg1: String, arg2: Int) {} + + static func make(resolver: Resolver) -> Self { + return .init( + arg1: resolver.string(), + arg2: resolver.int() + ) + } + """, + macros: testMacros + ) + } + + func test_closure_param() throws { + assertMacroExpansion( + """ + @Resolvable + init(closure: @escaping () -> Void) {} + """, + expandedSource: """ + + init(closure: @escaping () -> Void) {} + + static func make(resolver: CustomResolver) -> Self { + return .init( + closure: resolver.closure() + ) + } + """, + macros: testMacros + ) + } + + func test_default_param() throws { + assertMacroExpansion( + """ + @Resolvable + init(value: Int = 5) {} + """, + expandedSource: """ + + init(value: Int = 5) {} + + static func make(resolver: Resolver) -> Self { + return .init( + value: 5 + ) + } + """, + macros: testMacros + ) + } + + func test_argument() throws { + assertMacroExpansion( + """ + @Resolvable() + init(value: @Argument Int) {} + """, + expandedSource: """ + + init(value: @Argument Int) {} + + static func make(resolver: Resolver, value: Int) -> Self { + return .init( + value: value + ) + } + """, + macros: testMacros + ) + } + + func test_named() throws { + assertMacroExpansion( + """ + @Resolvable() + init(value: @Named("customName") Int) {} + """, + expandedSource: """ + + init(value: @Named("customName") Int) {} + + static func make(resolver: Resolver) -> Self { + return .init( + value: resolver.int(name: .customName) + ) + } + """, + macros: testMacros + ) + } + + func test_apply_static() throws { + assertMacroExpansion( + """ + @Resolvable + static func makeThing(value: Int) -> Thing { + Thing(value: value) + } + """, + expandedSource: """ + + static func makeThing(value: Int) -> Thing { + Thing(value: value) + } + + static func make(resolver: Resolver) -> Thing { + return makeThing( + value: resolver.int() + ) + } + """, + macros: testMacros + ) + } + + func test_non_static_function() throws { + assertMacroExpansion( + """ + @Resolvable + func makeThing(value: Int) -> Thing { .init() } + """, + expandedSource: """ + + func makeThing(value: Int) -> Thing { .init() } + """, + diagnostics: [ + .init( + message: "@Resolvable can only be used on init declarations or static functions", + line: 1, + column: 1 + ), + ], + macros: testMacros + ) + } +} diff --git a/Tests/KnitMacrosTests/SwinjectResolutionTests.swift b/Tests/KnitMacrosTests/SwinjectResolutionTests.swift new file mode 100644 index 0000000..2121a1e --- /dev/null +++ b/Tests/KnitMacrosTests/SwinjectResolutionTests.swift @@ -0,0 +1,137 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Foundation +import Swinject +import XCTest +import Knit +import KnitMacros + +final class SwinjectResolutionTests: XCTestCase { + + func test_simple_service() { + let container = Factory.container + container.register(Service1.self, factory: Service1.make) + XCTAssertNotNil(container.resolve(Service1.self)) + } + + func test_resolve_closure() { + let container = Factory.container + container.register(Service2.self, factory: Service2.make) + XCTAssertNotNil(container.resolve(Service2.self)) + } + + func test_default_value() { + let emptyContainer = Container() + emptyContainer.register(Service3.self, factory: Service3.make) + let defaultedService = emptyContainer.resolve(Service3.self) + XCTAssertEqual(defaultedService?.value, 2) + } + + /* Disabled due to import issues + func test_argument() { + let container = Container() + container.register(Service4.self, factory: Service4.make) + + let service = container.resolve(Service4.self, argument: Float(5)) + XCTAssertEqual(service?.value, 5) + } + + func test_named_parameter() { + let container = Container() + container.register(Float.self, name: "float2") { _ in 2} + container.register(Service5.self, factory: Service5.make) + + let service = container.resolve(Service5.self) + XCTAssertEqual(service?.value, 2) + } + */ +} + +private struct Service1 { + + let string: String + let value: Int + + @Resolvable + init(string: String, value: Int) { + self.string = string + self.value = value + } +} + +private struct Service2 { + let closure: () -> Void + + @Resolvable() + init(closure: @escaping () -> Void) { + self.closure = closure + } +} + +private struct Service3 { + + let value: Int + + @Resolvable + init(defaultedValue: Int = 2) { + self.value = defaultedValue + } +} +/* +private struct Service4 { + let value: Float + @Resolvable() + init(value: @Argument Float) { + self.value = value + } +} + +private struct Service5 { + let value: Float + @Resolvable() + init(value: @Named("float2") Float) { + self.value = value + } +} +*/ +private enum Factory { + static var container: Container { + let container = Container() + container.register(String.self) { _ in "Test" } + container.register(Int.self) { _ in 5 } + container.register((()->Void).self) { _ in + return { + print("Test") + } + } + + return container + } +} + +// Resolver functions to match what would be generated by Knit + +enum FloatName: String { + case float2 +} + +private extension Resolver { + + func float(name: FloatName) -> Float { + resolve(Float.self, name: name.rawValue)! + } + + func string() -> String { + resolve(String.self)! + } + + func int() -> Int { + resolve(Int.self)! + } + + func closure() -> ()->Void { + resolve((()->Void).self)! + } +}