diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java index ceb2a5d4b2..14471219e3 100644 --- a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java @@ -18,6 +18,7 @@ import org.openrewrite.*; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.tree.*; +import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; import java.time.Duration; import java.util.List; @@ -28,6 +29,8 @@ @Incubating(since = "7.23.0") public class RemoveRedundantTypeCast extends Recipe { + private static final String REMOVE_UNNECESSARY_PARENTHESES = "removeUnnecessaryParentheses"; + @Override public String getDisplayName() { return "Remove redundant casts"; @@ -35,7 +38,7 @@ public String getDisplayName() { @Override public String getDescription() { - return "Removes unnecessary type casts. Does not currently check casts in lambdas, class constructors, and method invocations."; + return "Removes unnecessary type casts. Does not currently check casts in lambdas and class constructors."; } @Override @@ -50,7 +53,7 @@ public Set getTags() { @Override public TreeVisitor getVisitor() { - return new JavaVisitor() { + return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaVisitor() { @Override public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { J visited = super.visitTypeCast(typeCast, ctx); @@ -68,8 +71,14 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { J parentValue = parent.getValue(); + J.TypeCast visitedTypeCast = (J.TypeCast) visited; + JavaType expressionType = visitedTypeCast.getExpression().getType(); + JavaType castType = visitedTypeCast.getType(); + JavaType targetType = null; - if (parentValue instanceof J.VariableDeclarations) { + if (castType.equals(expressionType)) { + targetType = castType; + } else if (parentValue instanceof J.VariableDeclarations) { targetType = ((J.VariableDeclarations) parentValue).getVariables().get(0).getType(); } else if (parentValue instanceof MethodCall) { MethodCall methodCall = (MethodCall) parentValue; @@ -87,7 +96,10 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { } } } - } else if (parentValue instanceof J.Return && ((J.Return) parentValue).getExpression() == typeCast) { + if (TypeUtils.isAssignableTo(castType, expressionType)) { + targetType = castType; + } + } else if (parentValue instanceof J.Return && expressionIsTypeCast((J.Return) parentValue, typeCast)) { parent = parent.dropParentUntil(is -> is instanceof J.Lambda || is instanceof J.MethodDeclaration || is instanceof J.ClassDeclaration || @@ -98,10 +110,6 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { } } - J.TypeCast visitedTypeCast = (J.TypeCast) visited; - JavaType expressionType = visitedTypeCast.getExpression().getType(); - JavaType castType = visitedTypeCast.getType(); - if (targetType == null) { return visitedTypeCast; } @@ -135,11 +143,24 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { if (fullyQualified != null) { maybeRemoveImport(fullyQualified.getFullyQualifiedName()); } + Cursor directParent = getCursor().getParent(); + if (directParent != null && directParent.getParent() != null && directParent.getParent().getValue() instanceof J.Parentheses) { + directParent.getParent().putMessage(REMOVE_UNNECESSARY_PARENTHESES, true); + } return visitedTypeCast.getExpression().withPrefix(visitedTypeCast.getPrefix()); } return visitedTypeCast; } + @Override + public J visitParentheses(J.Parentheses parens, ExecutionContext ctx) { + J.Parentheses parentheses = (J.Parentheses) super.visitParentheses(parens, ctx); + if (getCursor().getMessage(REMOVE_UNNECESSARY_PARENTHESES, false)) { + return parentheses.getTree().withPrefix(parentheses.getPrefix()); + } + return parentheses; + } + private JavaType getParameterType(JavaType.Method method, int arg) { List parameterTypes = method.getParameterTypes(); if (parameterTypes.size() > arg) { @@ -149,6 +170,20 @@ private JavaType getParameterType(JavaType.Method method, int arg) { JavaType type = parameterTypes.get(parameterTypes.size() - 1); return type instanceof JavaType.Array ? ((JavaType.Array) type).getElemType() : type; } - }; + + private boolean expressionIsTypeCast(J.Return return_, J.TypeCast typeCast) { + if (return_.getExpression() instanceof J.Parentheses) { + return expressionIsTypeCast((J.Parentheses) return_.getExpression(), typeCast); + } + return return_.getExpression() == typeCast; + } + + private boolean expressionIsTypeCast(J.Parentheses parentheses, J.TypeCast typeCast) { + if (parentheses.getTree() instanceof J.Parentheses) { + return expressionIsTypeCast((J.Parentheses ) parentheses.getTree(), typeCast); + } + return parentheses.getTree() == typeCast; + } + }); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java index 87c69c0218..95ae3ab070 100644 --- a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java @@ -471,7 +471,7 @@ void removeImport() { class Test { List method(List list) { - return (ArrayList) list; + return ((ArrayList) list); } } """, @@ -556,4 +556,59 @@ public Bar baz() { ) ); } + + @Test + void chainedMethods() { + // language=java + rewriteRun( + java( + """ + class Bar { + String getName() { + return "The bar"; + } + } + class ChildBar extends Bar {} + """ + ), + java( + """ + class Foo { + public void getBarName() { + String.format(((Bar) getBar()).getName()); + ((Bar) getBar()).getName(); + ((Bar) getChildBar()).getName(); + ((((Bar) getBar()))).getName(); + } + + private Bar getBar() { + return new Bar(); + } + + private ChildBar getChildBar() { + return new ChildBar(); + } + } + """, + """ + class Foo { + public void getBarName() { + String.format(getBar().getName()); + getBar().getName(); + getChildBar().getName(); + ((getBar())).getName(); + } + + private Bar getBar() { + return new Bar(); + } + + private ChildBar getChildBar() { + return new ChildBar(); + } + } + """ + ) + ); + } }