Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,93 @@
import org.openrewrite.java.cleanup.SimplifyBooleanExpressionVisitor;
import org.openrewrite.java.cleanup.UnnecessaryParenthesesVisitor;
import org.openrewrite.java.service.ImportService;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.TypeUtils;

import java.util.EnumSet;
import java.util.List;

import static org.openrewrite.java.MethodMatcher.methodPattern;

@SuppressWarnings("unused")
public abstract class AbstractRefasterJavaVisitor extends JavaVisitor<ExecutionContext> {

/**
* Check whether the after template's return type is assignable to the target type
* expected by the surrounding context (e.g., receiver of a chained method call,
* argument to a method, or right-hand side of an assignment).
* Returns {@code true} if the replacement is safe: either the context doesn't constrain
* the type (standalone expression, return statement, etc.) or the after type is assignable
* to what the context expects. Returns {@code false} if the replacement would break
* compilation (e.g., a chained method call that doesn't exist on the wider type).
*/
protected boolean isAssignableToTargetType(String afterTypeFqn) {
Cursor parentCursor = getCursor().getParentTreeCursor();
Object parent = parentCursor.getValue();
Object child = getCursor().getValue();

if (parent instanceof J.MethodInvocation) {
J.MethodInvocation mi = (J.MethodInvocation) parent;
if (mi.getMethodType() == null) {
return true;
}
// Expression is the receiver — the method must exist on the replacement type
if (mi.getSelect() == child) {
return TypeUtils.isAssignableTo(afterTypeFqn, mi.getMethodType().getDeclaringType());
}
// Expression is an argument — check resolved overload first, then other overloads
List<Expression> args = mi.getArguments();
int argIndex = -1;
for (int i = 0; i < args.size(); i++) {
if (args.get(i) == child) {
argIndex = i;
break;
}
}
List<JavaType> parameterTypes = mi.getMethodType().getParameterTypes();
if (argIndex >= 0 && argIndex < parameterTypes.size()) {
if (TypeUtils.isAssignableTo(afterTypeFqn, parameterTypes.get(argIndex))) {
return true;
}
// Check if any other overload on the declaring type accepts the wider type
JavaType.FullyQualified declaringType = mi.getMethodType().getDeclaringType();
String methodName = mi.getMethodType().getName();
for (JavaType.Method method : declaringType.getMethods()) {
if (!method.getName().equals(methodName) || method.getParameterTypes().size() != args.size()) {
continue;
}
List<JavaType> paramTypes = method.getParameterTypes();
boolean allMatch = true;
for (int i = 0; i < args.size(); i++) {
JavaType argType = i == argIndex ? JavaType.buildType(afterTypeFqn) : args.get(i).getType();
if (argType == null || !TypeUtils.isAssignableTo(paramTypes.get(i), argType)) {
allMatch = false;
break;
}
}
if (allMatch) {
return true;
}
}
return false;
}
} else if (parent instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) parent;
if (assignment.getAssignment() == child) {
return TypeUtils.isAssignableTo(afterTypeFqn, assignment.getType());
}
} else if (parent instanceof J.VariableDeclarations.NamedVariable) {
J.VariableDeclarations.NamedVariable var = (J.VariableDeclarations.NamedVariable) parent;
if (var.getInitializer() == child) {
return TypeUtils.isAssignableTo(afterTypeFqn, var.getType());
}
}
// No constraint from context (standalone expression, return statement, etc.)
return true;
}

@SuppressWarnings("SameParameterValue")
protected J embed(J j, Cursor cursor, ExecutionContext ctx, EmbeddingOption... options) {
EnumSet<EmbeddingOption> optionsSet = options.length > 0 ? EnumSet.of(options[0], options) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.TypeTag;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.parser.Tokens;
import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import com.sun.tools.javac.tree.JCTree;
Expand Down Expand Up @@ -351,6 +352,30 @@ private String generateVisitMethod(Map<String, TemplateDescriptor> beforeTemplat
}

visitMethod.append(" JavaTemplate.Matcher matcher;\n");

// Check if any before template needs the type assignability guard
String hoistedGuardType = null;
if (descriptor.afterTemplate != null) {
Types types = Types.instance(processingEnv.getContext());
Type afterReturnType = descriptor.afterTemplate.method.getReturnType().type;
if (!(afterReturnType instanceof Type.JCVoidType)) {
for (TemplateDescriptor bt : beforeTemplates.values()) {
Type beforeReturnType = bt.method.getReturnType().type;
if (!(beforeReturnType instanceof Type.JCVoidType) &&
!types.isSubtype(types.erasure(afterReturnType), types.erasure(beforeReturnType))) {
hoistedGuardType = types.erasure(afterReturnType).tsym.getQualifiedName().toString();
break;
}
}
}
}
if (hoistedGuardType != null) {
visitMethod.append(" if (!isAssignableToTargetType(\"")
.append(hoistedGuardType).append("\")) {\n");
visitMethod.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n");
visitMethod.append(" }\n");
}

for (Map.Entry<String, TemplateDescriptor> entry : beforeTemplates.entrySet()) {
int arity = entry.getValue().getArity();
for (int i = 0; i < arity; i++) {
Expand Down
6 changes: 6 additions & 0 deletions src/test/resources/refaster/PreconditionsVerifierRecipes.java
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
@Override
public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) {
JavaTemplate.Matcher matcher;
if (!isAssignableToTargetType("java.lang.Object")) {
return super.visitMethodInvocation(elem, ctx);
}
if (before == null) {
before = JavaTemplate.builder("com.google.common.base.Strings.nullToEmpty(#{value:any(java.lang.String)})")
.bindType("java.lang.String")
Expand Down Expand Up @@ -272,6 +275,9 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
@Override
public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) {
JavaTemplate.Matcher matcher;
if (!isAssignableToTargetType("java.lang.Object")) {
return super.visitMethodInvocation(elem, ctx);
}
if (before == null) {
before = JavaTemplate.builder("com.google.common.base.Strings.nullToEmpty(#{value:any(java.lang.String)})")
.bindType("java.lang.String")
Expand Down