Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.openrewrite.*;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.tree.*;
import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker;

import java.time.Duration;
import java.util.List;
Expand All @@ -28,14 +29,16 @@

@Incubating(since = "7.23.0")
public class RemoveRedundantTypeCast extends Recipe {
private static final String REMOVE_UNNECESSARY_PARENTHESES = "removeUnnecessaryParentheses";

@Override
public String getDisplayName() {
return "Remove redundant casts";
}

@Override
public String getDescription() {
return "Removes unnecessary type casts. Does not currently check casts in lambdas, class constructors, and method invocations.";
return "Removes unnecessary type casts. Does not currently check casts in lambdas and class constructors.";
}

@Override
Expand All @@ -50,7 +53,7 @@ public Set<String> getTags() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return new JavaVisitor<ExecutionContext>() {
return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaVisitor<ExecutionContext>() {
@Override
public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
J visited = super.visitTypeCast(typeCast, ctx);
Expand All @@ -68,8 +71,14 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {

J parentValue = parent.getValue();

J.TypeCast visitedTypeCast = (J.TypeCast) visited;
JavaType expressionType = visitedTypeCast.getExpression().getType();
JavaType castType = visitedTypeCast.getType();

JavaType targetType = null;
if (parentValue instanceof J.VariableDeclarations) {
if (castType.equals(expressionType)) {
targetType = castType;
} else if (parentValue instanceof J.VariableDeclarations) {
targetType = ((J.VariableDeclarations) parentValue).getVariables().get(0).getType();
} else if (parentValue instanceof MethodCall) {
MethodCall methodCall = (MethodCall) parentValue;
Expand All @@ -87,7 +96,10 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
}
}
}
} else if (parentValue instanceof J.Return && ((J.Return) parentValue).getExpression() == typeCast) {
if (TypeUtils.isAssignableTo(castType, expressionType)) {
targetType = castType;
}
} else if (parentValue instanceof J.Return && expressionIsTypeCast((J.Return) parentValue, typeCast)) {
parent = parent.dropParentUntil(is -> is instanceof J.Lambda ||
is instanceof J.MethodDeclaration ||
is instanceof J.ClassDeclaration ||
Expand All @@ -98,10 +110,6 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
}
}

J.TypeCast visitedTypeCast = (J.TypeCast) visited;
JavaType expressionType = visitedTypeCast.getExpression().getType();
JavaType castType = visitedTypeCast.getType();

if (targetType == null) {
return visitedTypeCast;
}
Expand Down Expand Up @@ -135,11 +143,24 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
if (fullyQualified != null) {
maybeRemoveImport(fullyQualified.getFullyQualifiedName());
}
Cursor directParent = getCursor().getParent();
if (directParent != null && directParent.getParent() != null && directParent.getParent().getValue() instanceof J.Parentheses) {
directParent.getParent().putMessage(REMOVE_UNNECESSARY_PARENTHESES, true);
}
return visitedTypeCast.getExpression().withPrefix(visitedTypeCast.getPrefix());
}
return visitedTypeCast;
}

@Override
public <T extends J> J visitParentheses(J.Parentheses<T> parens, ExecutionContext ctx) {
J.Parentheses<T> parentheses = (J.Parentheses<T>) super.visitParentheses(parens, ctx);
if (getCursor().getMessage(REMOVE_UNNECESSARY_PARENTHESES, false)) {
return parentheses.getTree().withPrefix(parentheses.getPrefix());
}
return parentheses;
}

private JavaType getParameterType(JavaType.Method method, int arg) {
List<JavaType> parameterTypes = method.getParameterTypes();
if (parameterTypes.size() > arg) {
Expand All @@ -149,6 +170,20 @@ private JavaType getParameterType(JavaType.Method method, int arg) {
JavaType type = parameterTypes.get(parameterTypes.size() - 1);
return type instanceof JavaType.Array ? ((JavaType.Array) type).getElemType() : type;
}
};

private boolean expressionIsTypeCast(J.Return return_, J.TypeCast typeCast) {
if (return_.getExpression() instanceof J.Parentheses<?>) {
return expressionIsTypeCast((J.Parentheses<?>) return_.getExpression(), typeCast);
}
return return_.getExpression() == typeCast;
}

private boolean expressionIsTypeCast(J.Parentheses<?> parentheses, J.TypeCast typeCast) {
if (parentheses.getTree() instanceof J.Parentheses) {
return expressionIsTypeCast((J.Parentheses<?> ) parentheses.getTree(), typeCast);
}
return parentheses.getTree() == typeCast;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ void removeImport() {

class Test {
List method(List list) {
return (ArrayList) list;
return ((ArrayList) list);
}
}
""",
Expand Down Expand Up @@ -556,4 +556,59 @@ public Bar baz() {
)
);
}

@Test
void chainedMethods() {
// language=java
rewriteRun(
java(
"""
class Bar {
String getName() {
return "The bar";
}
}
class ChildBar extends Bar {}
"""
),
java(
"""
class Foo {
public void getBarName() {
String.format(((Bar) getBar()).getName());
((Bar) getBar()).getName();
((Bar) getChildBar()).getName();
((((Bar) getBar()))).getName();
}

private Bar getBar() {
return new Bar();
}

private ChildBar getChildBar() {
return new ChildBar();
}
}
""",
"""
class Foo {
public void getBarName() {
String.format(getBar().getName());
getBar().getName();
getChildBar().getName();
((getBar())).getName();
}

private Bar getBar() {
return new Bar();
}

private ChildBar getChildBar() {
return new ChildBar();
}
}
"""
)
);
}
}