diff --git a/api/swiftpoet.api b/api/swiftpoet.api index d6404d70..b0d87266 100644 --- a/api/swiftpoet.api +++ b/api/swiftpoet.api @@ -107,7 +107,9 @@ public final class io/outfoxx/swiftpoet/DeclaredTypeName : io/outfoxx/swiftpoet/ public final fun getSimpleName ()Ljava/lang/String; public final fun getSimpleNames ()Ljava/util/List; public final fun nestedType (Ljava/lang/String;Z)Lio/outfoxx/swiftpoet/DeclaredTypeName; + public final fun nestedType (Ljava/util/List;Z)Lio/outfoxx/swiftpoet/DeclaredTypeName; public static synthetic fun nestedType$default (Lio/outfoxx/swiftpoet/DeclaredTypeName;Ljava/lang/String;ZILjava/lang/Object;)Lio/outfoxx/swiftpoet/DeclaredTypeName; + public static synthetic fun nestedType$default (Lio/outfoxx/swiftpoet/DeclaredTypeName;Ljava/util/List;ZILjava/lang/Object;)Lio/outfoxx/swiftpoet/DeclaredTypeName; public final fun peerType (Ljava/lang/String;Z)Lio/outfoxx/swiftpoet/DeclaredTypeName; public static synthetic fun peerType$default (Lio/outfoxx/swiftpoet/DeclaredTypeName;Ljava/lang/String;ZILjava/lang/Object;)Lio/outfoxx/swiftpoet/DeclaredTypeName; public static final fun qualifiedTypeName (Ljava/lang/String;)Lio/outfoxx/swiftpoet/DeclaredTypeName; diff --git a/src/main/java/io/outfoxx/swiftpoet/CodeWriter.kt b/src/main/java/io/outfoxx/swiftpoet/CodeWriter.kt index d7d9e1f2..38a91dcd 100644 --- a/src/main/java/io/outfoxx/swiftpoet/CodeWriter.kt +++ b/src/main/java/io/outfoxx/swiftpoet/CodeWriter.kt @@ -26,12 +26,13 @@ private val NO_MODULE = String() * Converts a [FileSpec] to a string suitable to both human- and swiftc-consumption. This honors * imports, indentation, and variable names. */ -internal class CodeWriter constructor( +internal class CodeWriter( out: Appendable, private val indent: String = DEFAULT_INDENT, internal val importedTypes: Map = emptyMap(), private val importedModules: Set = emptySet() ) : Closeable { + private val out = LineWrapper(out, indent, 100) private var indentLevel = 0 @@ -40,6 +41,7 @@ internal class CodeWriter constructor( private var moduleStack = mutableListOf(NO_MODULE) private val typeSpecStack = mutableListOf() private val importableTypes = mutableMapOf() + private val referencedTypes = mutableMapOf() private var trailingNewline = false /** @@ -205,8 +207,7 @@ internal class CodeWriter constructor( var a = 0 val partIterator = codeBlock.formatParts.listIterator() while (partIterator.hasNext()) { - val part = partIterator.next() - when (part) { + when (val part = partIterator.next()) { "%L" -> emitLiteral(codeBlock.args[a++], isConstantContext) "%N" -> emit(codeBlock.args[a++] as String) @@ -269,33 +270,43 @@ internal class CodeWriter constructor( } } + private fun referenceTypeName(typeName: DeclaredTypeName) { + referencedTypes[typeName.canonicalName] = typeName + } + /** * Returns the best name to identify `typeName` with in the current context. This uses the * available imports and the current scope to find the shortest name available. It does not honor * names visible due to inheritance. */ fun lookupName(typeName: DeclaredTypeName): String { + + // Track all referenced type names, Swift needs to import the module for each type + referenceTypeName(typeName) + + if (typeName.alwaysQualify) { + return typeName.canonicalName + } + // Find the shortest suffix of typeName that resolves to typeName. This uses both local type // names (so `Entry` in `Map` refers to `Map.Entry`). Also uses imports. - var nameResolved = false - var c: DeclaredTypeName? = typeName - while (c != null) { - val simpleName = c.simpleName - val resolved = resolve(simpleName) - nameResolved = resolved != null - - if (resolved == c.unwrapOptional()) { - val suffixOffset = c.simpleNames.size - 1 - return typeName.simpleNames.subList(suffixOffset, typeName.simpleNames.size) - .joinToString(".") + var currentTypeName: DeclaredTypeName? = typeName + val currentNestedSimpleNames = mutableListOf() + while (currentTypeName != null) { + val simpleName = currentTypeName.simpleName + val resolved = resolve(simpleName)?.nestedType(currentNestedSimpleNames) + + if (resolved == typeName.unwrapOptional()) { + // If the type is the same as the type we're resolving for, we must use at least that name. + if (currentNestedSimpleNames.isEmpty()) { + return simpleName + } + // Otherwise, we need to use all the nested names that didn't match + return currentNestedSimpleNames.joinToString(".") } - c = c.enclosingTypeName() - } - - // If the name resolved but wasn't a match, we're stuck with the fully qualified name. - if (nameResolved) { - return typeName.canonicalName + currentNestedSimpleNames.add(0, simpleName) + currentTypeName = currentTypeName.enclosingTypeName() } // If the type is in the same module, we're done. @@ -304,7 +315,7 @@ internal class CodeWriter constructor( } // If the type is in a manually imported module and doesn't clash, use an unqualified type - if (importedModules.contains(typeName.moduleName) && !importableTypes.containsKey(typeName.simpleName)) { + if (importedModules.contains(typeName.moduleName) && !importedTypes.containsKey(typeName.simpleName)) { return typeName.simpleName } @@ -313,7 +324,7 @@ internal class CodeWriter constructor( importableType(typeName) } - return typeName.canonicalName + return resolveImport(typeName) } private fun importableType(typeName: DeclaredTypeName) { @@ -322,10 +333,7 @@ internal class CodeWriter constructor( } val topLevelTypeName = typeName.topLevelTypeName() val simpleName = topLevelTypeName.simpleName - val replaced = importableTypes.put(simpleName, topLevelTypeName) - if (replaced != null) { - importableTypes[simpleName] = replaced // On collision, prefer the first inserted. - } + importableTypes.putIfAbsent(simpleName, topLevelTypeName) } /** @@ -337,12 +345,12 @@ internal class CodeWriter constructor( val typeSpec = typeSpecStack[i] if (typeSpec is ExternalTypeSpec) { if (typeSpec.name == simpleName) { - return stackTypeName(i, simpleName) + return stackTypeName(i) } } for (visibleChild in typeSpec.typeSpecs) { if (visibleChild.name == simpleName) { - return stackTypeName(i, simpleName) + return stackTypeName(i).nestedType(simpleName) } } } @@ -352,21 +360,29 @@ internal class CodeWriter constructor( return DeclaredTypeName(moduleStack.last(), simpleName) } - // Match an imported type. - val importedType = importedTypes[simpleName] - if (importedType != null) return importedType - // No match. return null } + /** + * Looks up `typeName` in the imports and returns the shortest name possible for that type name. + */ + private fun resolveImport(typeName: DeclaredTypeName): String { + val topLevelTypeName = typeName.topLevelTypeName() + return if (importedTypes.values.any { it == topLevelTypeName }) { + typeName.simpleNames.joinToString(".") + } else { + typeName.canonicalName + } + } + /** Returns the type named `simpleName` when nested in the type at `stackDepth`. */ - private fun stackTypeName(stackDepth: Int, simpleName: String): DeclaredTypeName { + private fun stackTypeName(stackDepth: Int): DeclaredTypeName { var typeName = DeclaredTypeName(moduleStack.last(), typeSpecStack[0].name) for (i in 1..stackDepth) { typeName = typeName.nestedType(typeSpecStack[i].name) } - return typeName.nestedType(simpleName) + return typeName } /** @@ -422,54 +438,28 @@ internal class CodeWriter constructor( } /** - * Returns the modules that should have been imported for this code. + * Returns the non-colliding importable types and module names for all referenced types. */ - private fun suggestedImports(): Map { - return importableTypes + private fun generateImports(): Pair, Set> { + return importableTypes to referencedTypes.values.map { it.moduleName }.toSet() } companion object { + /** - * Makes a pass to collect imports by executing [emitStep], and returns an instance of - * [CodeWriter] pre-initialized with collected imports. + * Collect imports by executing [emitStep], and returns the non-colliding imported types + * and referenced modules. */ - fun withCollectedImports( - out: Appendable, + fun collectImports( indent: String, emitStep: (importsCollector: CodeWriter) -> Unit, - ): CodeWriter { - // First pass: emit the entire class, just to collect the types we'll need to import. - val suggestedImports = CodeWriter( - NullAppendable, - indent, - ).use { importsCollector -> - emitStep(importsCollector) - - val generatedImports = mutableMapOf() - importsCollector.suggestedImports() - .generateImports( - generatedImports, - canonicalName = DeclaredTypeName::canonicalName, - ) - } + ): Pair, Set> = + CodeWriter(NullAppendable, indent) + .use { importsCollector -> - return CodeWriter( - out, - indent, - suggestedImports, - ) - } + emitStep(importsCollector) - private fun Map.generateImports( - generatedImports: MutableMap, - canonicalName: T.() -> String, - ): Map { - return flatMap { (simpleName, qualifiedName) -> - listOf(simpleName to qualifiedName).also { - val canonicalName = qualifiedName.canonicalName() - generatedImports[canonicalName] = canonicalName + importsCollector.generateImports() } - }.toMap() - } } } diff --git a/src/main/java/io/outfoxx/swiftpoet/DeclaredTypeName.kt b/src/main/java/io/outfoxx/swiftpoet/DeclaredTypeName.kt index 6be4ffa9..e4c9bbe9 100644 --- a/src/main/java/io/outfoxx/swiftpoet/DeclaredTypeName.kt +++ b/src/main/java/io/outfoxx/swiftpoet/DeclaredTypeName.kt @@ -68,6 +68,13 @@ class DeclaredTypeName internal constructor( fun nestedType(name: String, alwaysQualify: Boolean = this.alwaysQualify) = DeclaredTypeName(names + name, alwaysQualify) + /** + * Returns a new [DeclaredTypeName] instance for the specified `names` as nested inside this + * type. + */ + fun nestedType(names: List, alwaysQualify: Boolean = this.alwaysQualify) = + DeclaredTypeName(this.names + names, alwaysQualify) + /** * Returns a type that shares the same enclosing package or type. If this type is enclosed by * another type, this is equivalent to `enclosingTypeName().nestedType(name)`. Otherwise @@ -82,7 +89,7 @@ class DeclaredTypeName internal constructor( override fun compareTo(other: DeclaredTypeName) = canonicalName.compareTo(other.canonicalName) override fun emit(out: CodeWriter) = - out.emit(escapeKeywords(if (alwaysQualify) canonicalName else out.lookupName(this))) + out.emit(escapeKeywords(out.lookupName(this))) companion object { @JvmStatic fun typeName(qualifiedTypeName: String, alwaysQualify: Boolean = false): DeclaredTypeName { diff --git a/src/main/java/io/outfoxx/swiftpoet/FileSpec.kt b/src/main/java/io/outfoxx/swiftpoet/FileSpec.kt index b0763098..efcfeb12 100644 --- a/src/main/java/io/outfoxx/swiftpoet/FileSpec.kt +++ b/src/main/java/io/outfoxx/swiftpoet/FileSpec.kt @@ -47,12 +47,15 @@ class FileSpec private constructor( @Throws(IOException::class) fun writeTo(out: Appendable) { - val codeWriter = CodeWriter.withCollectedImports( - out = out, - indent = indent, - emitStep = { importsCollector -> emit(importsCollector) }, - ) - codeWriter.use(::emit) + + val (importedTypes, referencedModules) = + CodeWriter.collectImports( + indent = indent, + emitStep = { importsCollector -> emit(importsCollector) }, + ) + + val codeWriter = CodeWriter(out, indent = indent, importedTypes = importedTypes) + emit(codeWriter, referencedModules = referencedModules) } /** Writes this to `directory` as UTF-8 using the standard directory structure. */ @@ -70,19 +73,20 @@ class FileSpec private constructor( @Throws(IOException::class) fun writeTo(directory: File) = writeTo(directory.toPath()) - private fun emit(codeWriter: CodeWriter) { + private fun emit(codeWriter: CodeWriter, referencedModules: Set = setOf()) { if (comment.isNotEmpty()) { codeWriter.emitComment(comment) } codeWriter.pushModule(moduleName) - val importedTypeImports = codeWriter.importedTypes.map { ImportSpec.builder(it.value.moduleName).build() } - val allImports = moduleImports + importedTypeImports - val imports = allImports.filter { it.name != "Swift" } + val implicitModuleImports = referencedModules.map { ImportSpec.builder(it).build() } + val allModuleImports = moduleImports + implicitModuleImports + val nonImportedModules = NON_IMPORTED_MODULES + moduleName + val moduleImports = allModuleImports.filterNot { nonImportedModules.contains(it.name) } - if (imports.isNotEmpty()) { - for (import in imports.toSortedSet()) { + if (moduleImports.isNotEmpty()) { + for (import in moduleImports.toSortedSet()) { import.emit(codeWriter) codeWriter.emit("\n") } @@ -173,6 +177,9 @@ class FileSpec private constructor( } companion object { + + private val NON_IMPORTED_MODULES = setOf("Swift") + @JvmStatic fun get(moduleName: String, typeSpec: AnyTypeSpec): FileSpec { return builder(moduleName, typeSpec.name).addType(typeSpec).build() } diff --git a/src/test/java/io/outfoxx/swiftpoet/test/FileSpecTests.kt b/src/test/java/io/outfoxx/swiftpoet/test/FileSpecTests.kt index cd581aae..403dc70c 100644 --- a/src/test/java/io/outfoxx/swiftpoet/test/FileSpecTests.kt +++ b/src/test/java/io/outfoxx/swiftpoet/test/FileSpecTests.kt @@ -73,6 +73,8 @@ class FileSpecTests { out.toString(), equalTo( """ + import Special + let value: Special.Array """.trimIndent() ) @@ -80,7 +82,7 @@ class FileSpecTests { } @Test - @DisplayName("Generates correct imports when extending different types") + @DisplayName("Generates correct imports when extending different type names") fun testImportsForDifferentExtensionTypes() { val parentElement = typeName("Foundation.Data") val obsElement = typeName("RxSwift.Observable.Element") @@ -121,18 +123,30 @@ class FileSpecTests { } @Test - @DisplayName("Generates correct imports for extension types") + @DisplayName("Generates correct imports for extension type names") fun testImportsForSameExtensionTypes() { + val obs = typeName("RxSwift.Observable") val obsElement = typeName("RxSwift.Observable.Element") + val obsElementSub = typeName("RxSwift.Observable.Element.SubSequence") val extension = ExtensionSpec.builder(obsElement.enclosingTypeName()!!) .addFunction( FunctionSpec.builder("test") + .returns(obs) + .build() + ) + .addFunction( + FunctionSpec.builder("test2") .returns(obsElement) .build() ) + .addFunction( + FunctionSpec.builder("test3") + .returns(obsElementSub) + .build() + ) .build() val testFile = FileSpec.builder("Test", "Test") @@ -150,7 +164,13 @@ class FileSpecTests { extension Observable { - func test() -> RxSwift.Observable.Element { + func test() -> Observable { + } + + func test2() -> Element { + } + + func test3() -> Element.SubSequence { } } @@ -244,10 +264,101 @@ class FileSpecTests { out.toString(), equalTo( """ - class Test { + import Foundation + + class Test { + + let a: Array + let b: Foundation.Array + + } + + """.trimIndent() + ) + ) + } + + @Test + @DisplayName("Generates all required imports with conflicts (alwaysQualify)") + fun testGeneratesAllRequiredImportsWithConflictsUsingAlwaysQualify() { + val type = TypeSpec.structBuilder("SomeType") + .addProperty( + PropertySpec.varBuilder( + "foundation_order", + typeName("Foundation.SortOrder", alwaysQualify = true) + ).build() + ) + .addProperty( + PropertySpec.varBuilder( + "order", + typeName("some_other_module.SortOrder") + ).build() + ) + .build() + + val testFile = FileSpec.builder("Test", "Test") + .addType(type) + .build() + + val out = StringWriter() + testFile.writeTo(out) + + assertThat( + out.toString(), + equalTo( + """ + import Foundation + import some_other_module + + struct SomeType { + + var foundation_order: Foundation.SortOrder + var order: SortOrder + + } + + """.trimIndent() + ) + ) + } + + @Test + @DisplayName("Generates all required imports with conflicts") + fun testGeneratesAllRequiredImportsWithConflicts() { + val type = + TypeSpec.structBuilder("SomeType") + .addProperty( + PropertySpec.varBuilder( + "foundation_order", + typeName("Foundation.SortOrder") + ).build() + ) + .addProperty( + PropertySpec.varBuilder( + "order", + typeName("some_other_module.SortOrder") + ).build() + ) + .build() + + val testFile = FileSpec.builder("Test", "Test") + .addType(type) + .build() + + val out = StringWriter() + testFile.writeTo(out) + + assertThat( + out.toString(), + equalTo( + """ + import Foundation + import some_other_module + + struct SomeType { - let a: Array - let b: Foundation.Array + var foundation_order: SortOrder + var order: some_other_module.SortOrder }