diff --git a/src/main/java/org/openrewrite/staticanalysis/UseCollectionInterfaces.java b/src/main/java/org/openrewrite/staticanalysis/UseCollectionInterfaces.java index 0127d899da..526329e7ab 100644 --- a/src/main/java/org/openrewrite/staticanalysis/UseCollectionInterfaces.java +++ b/src/main/java/org/openrewrite/staticanalysis/UseCollectionInterfaces.java @@ -15,6 +15,7 @@ */ package org.openrewrite.staticanalysis; +import org.jetbrains.annotations.Contract; import org.jspecify.annotations.Nullable; import org.openrewrite.*; import org.openrewrite.groovy.tree.G; @@ -114,46 +115,6 @@ public J visit(@Nullable Tree tree, ExecutionContext ctx) { return super.visit(tree, ctx); } - @Override - public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { - J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); - if ((m.hasModifier(J.Modifier.Type.Public) || m.hasModifier(J.Modifier.Type.Private) || m.getModifiers().isEmpty()) && - m.getReturnTypeExpression() != null) { - JavaType.FullyQualified originalType = TypeUtils.asFullyQualified(m.getReturnTypeExpression().getType()); - if (originalType != null && rspecRulesReplaceTypeMap.containsKey(originalType.getFullyQualifiedName())) { - - JavaType.FullyQualified newType = TypeUtils.asFullyQualified( - JavaType.buildType(rspecRulesReplaceTypeMap.get(originalType.getFullyQualifiedName()))); - if (newType != null) { - maybeRemoveImport(originalType); - maybeAddImport(newType); - - TypeTree typeExpression; - if (m.getReturnTypeExpression() instanceof J.Identifier) { - typeExpression = new J.Identifier( - randomId(), - m.getReturnTypeExpression().getPrefix(), - Markers.EMPTY, - emptyList(), - newType.getClassName(), - newType, - null - ); - } else if (m.getReturnTypeExpression() instanceof J.AnnotatedType) { - J.AnnotatedType annotatedType = (J.AnnotatedType) m.getReturnTypeExpression(); - J.ParameterizedType parameterizedType = (J.ParameterizedType) annotatedType.getTypeExpression(); - typeExpression = annotatedType.withTypeExpression(removeFromParameterizedType(newType, parameterizedType)); - } else { - J.ParameterizedType parameterizedType = (J.ParameterizedType) m.getReturnTypeExpression(); - typeExpression = removeFromParameterizedType(newType, parameterizedType); - } - m = m.withReturnTypeExpression(typeExpression); - } - } - } - return m; - } - @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); @@ -176,6 +137,41 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return mi; } + private J.MethodInvocation updateMethodInvocation(J.MethodInvocation mi, JavaType.FullyQualified newType) { + if (mi.getSelect() != null) { + mi = mi.withSelect(mi.getSelect().withType(newType)); + if (mi.getSelect() instanceof J.FieldAccess) { + J.FieldAccess fieldAccess = (J.FieldAccess) mi.getSelect(); + mi = mi.withSelect(fieldAccess.withName(fieldAccess.getName().withType(newType))); + } + } + if (mi.getMethodType() != null) { + mi = mi.withMethodType(mi.getMethodType().withDeclaringType(newType)); + } + return mi; + } + + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); + if ((m.hasModifier(J.Modifier.Type.Public) || m.hasModifier(J.Modifier.Type.Private) || m.getModifiers().isEmpty()) && + m.getReturnTypeExpression() != null) { + JavaType.FullyQualified originalType = TypeUtils.asFullyQualified(m.getReturnTypeExpression().getType()); + if (originalType != null && rspecRulesReplaceTypeMap.containsKey(originalType.getFullyQualifiedName())) { + + JavaType.FullyQualified newType = TypeUtils.asFullyQualified( + JavaType.buildType(rspecRulesReplaceTypeMap.get(originalType.getFullyQualifiedName()))); + if (newType != null) { + maybeRemoveImport(originalType); + maybeAddImport(newType); + + m = m.withReturnTypeExpression(getTypeTree(m.getReturnTypeExpression(), newType)); + } + } + } + return m; + } + @Override public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) { J.VariableDeclarations mv = super.visitVariableDeclarations(multiVariable, ctx); @@ -192,29 +188,7 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations m maybeRemoveImport(originalType); maybeAddImport(newType); - TypeTree typeExpression; - if (mv.getTypeExpression() == null) { - typeExpression = null; - } else if (mv.getTypeExpression() instanceof J.Identifier) { - typeExpression = new J.Identifier( - randomId(), - mv.getTypeExpression().getPrefix(), - Markers.EMPTY, - emptyList(), - newType.getClassName(), - newType, - null - ); - } else if (mv.getTypeExpression() instanceof J.AnnotatedType) { - J.AnnotatedType annotatedType = (J.AnnotatedType) mv.getTypeExpression(); - J.ParameterizedType parameterizedType = (J.ParameterizedType) annotatedType.getTypeExpression(); - typeExpression = annotatedType.withTypeExpression(removeFromParameterizedType(newType, parameterizedType)); - } else { - J.ParameterizedType parameterizedType = (J.ParameterizedType) mv.getTypeExpression(); - typeExpression = removeFromParameterizedType(newType, parameterizedType); - } - - mv = mv.withTypeExpression(typeExpression); + mv = mv.withTypeExpression(getTypeTree(mv.getTypeExpression(), newType)); mv = mv.withVariables(ListUtils.map(mv.getVariables(), var -> { JavaType.FullyQualified varType = TypeUtils.asFullyQualified(var.getType()); if (varType != null && !varType.equals(newType)) { @@ -227,18 +201,53 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations m return mv; } - private J.MethodInvocation updateMethodInvocation(J.MethodInvocation mi, JavaType.FullyQualified newType) { - if (mi.getSelect() != null) { - mi = mi.withSelect(mi.getSelect().withType(newType)); - if (mi.getSelect() instanceof J.FieldAccess) { - J.FieldAccess fieldAccess = (J.FieldAccess) mi.getSelect(); - mi = mi.withSelect(fieldAccess.withName(fieldAccess.getName().withType(newType))); - } + @Contract("null, _ -> null") + private @Nullable TypeTree getTypeTree(@Nullable TypeTree inputTypeExpression, JavaType.FullyQualified newType) { + if (inputTypeExpression == null) { + return null; } - if (mi.getMethodType() != null) { - mi = mi.withMethodType(mi.getMethodType().withDeclaringType(newType)); + if (inputTypeExpression instanceof J.Identifier) { + return new J.Identifier( + randomId(), + inputTypeExpression.getPrefix(), + Markers.EMPTY, + emptyList(), + newType.getClassName(), + newType, + null + ); } - return mi; + if (inputTypeExpression instanceof J.FieldAccess) { + // Fully-qualified type name like java.util.HashSet + return new J.Identifier( + randomId(), + inputTypeExpression.getPrefix(), + Markers.EMPTY, + emptyList(), + newType.getClassName(), + newType, + null + ); + } + if (inputTypeExpression instanceof J.AnnotatedType) { + J.AnnotatedType annotatedType = (J.AnnotatedType) inputTypeExpression; + TypeTree annotatedTypeExpression = annotatedType.getTypeExpression(); + if (annotatedTypeExpression instanceof J.Identifier || annotatedTypeExpression instanceof J.FieldAccess) { + return annotatedType.withTypeExpression(new J.Identifier( + randomId(), + annotatedTypeExpression.getPrefix(), + Markers.EMPTY, + emptyList(), + newType.getClassName(), + newType, + null + )); + } + J.ParameterizedType parameterizedType = (J.ParameterizedType) annotatedTypeExpression; + return annotatedType.withTypeExpression(removeFromParameterizedType(newType, parameterizedType)); + } + J.ParameterizedType parameterizedType = (J.ParameterizedType) inputTypeExpression; + return removeFromParameterizedType(newType, parameterizedType); } private TypeTree removeFromParameterizedType(JavaType.FullyQualified newType, diff --git a/src/test/java/org/openrewrite/staticanalysis/UseCollectionInterfacesTest.java b/src/test/java/org/openrewrite/staticanalysis/UseCollectionInterfacesTest.java index 6429936454..ba344da8b2 100644 --- a/src/test/java/org/openrewrite/staticanalysis/UseCollectionInterfacesTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/UseCollectionInterfacesTest.java @@ -1211,4 +1211,142 @@ Enumeration usesVectorElements() { ) ); } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/713") + @Test + void annotatedReturnTypeRawArrayList() { + rewriteRun( + spec -> spec.parser(JavaParser.fromJavaVersion().classpath("jspecify")), + //language=java + java( + """ + import java.util.ArrayList; + import org.jspecify.annotations.Nullable; + + class Test { + public @Nullable ArrayList transform() { + ArrayList res = new ArrayList(); + return res; + } + } + """, + """ + import java.util.ArrayList; + import java.util.List; + + import org.jspecify.annotations.Nullable; + + class Test { + public @Nullable List transform() { + List res = new ArrayList(); + return res; + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/713") + @Test + void annotatedFieldTypeRawArrayList() { + rewriteRun( + spec -> spec.parser(JavaParser.fromJavaVersion().classpath("jspecify")), + //language=java + java( + """ + import java.util.ArrayList; + import org.jspecify.annotations.Nullable; + + class Test { + public @Nullable ArrayList values = new ArrayList(); + } + """, + """ + import java.util.ArrayList; + import java.util.List; + + import org.jspecify.annotations.Nullable; + + class Test { + public @Nullable List values = new ArrayList(); + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/716") + @Test + void fullyQualifiedRawType() { + rewriteRun( + //language=java + java( + """ + class Test { + public java.util.HashSet method() { + return new java.util.HashSet<>(); + } + } + """, + """ + import java.util.Set; + + class Test { + public Set method() { + return new java.util.HashSet<>(); + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/716") + @Test + void fullyQualifiedParameterizedType() { + rewriteRun( + //language=java + java( + """ + class Test { + public java.util.HashSet method() { + return new java.util.HashSet<>(); + } + } + """, + """ + import java.util.Set; + + class Test { + public Set method() { + return new java.util.HashSet<>(); + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/716") + @Test + void fullyQualifiedFieldType() { + rewriteRun( + //language=java + java( + """ + class Test { + public java.util.HashSet values = new java.util.HashSet(); + } + """, + """ + import java.util.Set; + + class Test { + public Set values = new java.util.HashSet(); + } + """ + ) + ); + } }