diff --git a/src/main/java/org/openrewrite/java/template/internal/AbstractRefasterJavaVisitor.java b/src/main/java/org/openrewrite/java/template/internal/AbstractRefasterJavaVisitor.java index 8c1824d..5b978de 100644 --- a/src/main/java/org/openrewrite/java/template/internal/AbstractRefasterJavaVisitor.java +++ b/src/main/java/org/openrewrite/java/template/internal/AbstractRefasterJavaVisitor.java @@ -23,15 +23,93 @@ import org.openrewrite.java.cleanup.SimplifyBooleanExpressionVisitor; import org.openrewrite.java.cleanup.UnnecessaryParenthesesVisitor; import org.openrewrite.java.service.ImportService; +import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.TypeUtils; import java.util.EnumSet; +import java.util.List; import static org.openrewrite.java.MethodMatcher.methodPattern; @SuppressWarnings("unused") public abstract class AbstractRefasterJavaVisitor extends JavaVisitor { + /** + * Check whether the after template's return type is assignable to the target type + * expected by the surrounding context (e.g., receiver of a chained method call, + * argument to a method, or right-hand side of an assignment). + * Returns {@code true} if the replacement is safe: either the context doesn't constrain + * the type (standalone expression, return statement, etc.) or the after type is assignable + * to what the context expects. Returns {@code false} if the replacement would break + * compilation (e.g., a chained method call that doesn't exist on the wider type). + */ + protected boolean isAssignableToTargetType(String afterTypeFqn) { + Cursor parentCursor = getCursor().getParentTreeCursor(); + Object parent = parentCursor.getValue(); + Object child = getCursor().getValue(); + + if (parent instanceof J.MethodInvocation) { + J.MethodInvocation mi = (J.MethodInvocation) parent; + if (mi.getMethodType() == null) { + return true; + } + // Expression is the receiver — the method must exist on the replacement type + if (mi.getSelect() == child) { + return TypeUtils.isAssignableTo(afterTypeFqn, mi.getMethodType().getDeclaringType()); + } + // Expression is an argument — check resolved overload first, then other overloads + List args = mi.getArguments(); + int argIndex = -1; + for (int i = 0; i < args.size(); i++) { + if (args.get(i) == child) { + argIndex = i; + break; + } + } + List parameterTypes = mi.getMethodType().getParameterTypes(); + if (argIndex >= 0 && argIndex < parameterTypes.size()) { + if (TypeUtils.isAssignableTo(afterTypeFqn, parameterTypes.get(argIndex))) { + return true; + } + // Check if any other overload on the declaring type accepts the wider type + JavaType.FullyQualified declaringType = mi.getMethodType().getDeclaringType(); + String methodName = mi.getMethodType().getName(); + for (JavaType.Method method : declaringType.getMethods()) { + if (!method.getName().equals(methodName) || method.getParameterTypes().size() != args.size()) { + continue; + } + List paramTypes = method.getParameterTypes(); + boolean allMatch = true; + for (int i = 0; i < args.size(); i++) { + JavaType argType = i == argIndex ? JavaType.buildType(afterTypeFqn) : args.get(i).getType(); + if (argType == null || !TypeUtils.isAssignableTo(paramTypes.get(i), argType)) { + allMatch = false; + break; + } + } + if (allMatch) { + return true; + } + } + return false; + } + } else if (parent instanceof J.Assignment) { + J.Assignment assignment = (J.Assignment) parent; + if (assignment.getAssignment() == child) { + return TypeUtils.isAssignableTo(afterTypeFqn, assignment.getType()); + } + } else if (parent instanceof J.VariableDeclarations.NamedVariable) { + J.VariableDeclarations.NamedVariable var = (J.VariableDeclarations.NamedVariable) parent; + if (var.getInitializer() == child) { + return TypeUtils.isAssignableTo(afterTypeFqn, var.getType()); + } + } + // No constraint from context (standalone expression, return statement, etc.) + return true; + } + @SuppressWarnings("SameParameterValue") protected J embed(J j, Cursor cursor, ExecutionContext ctx, EmbeddingOption... options) { EnumSet optionsSet = options.length > 0 ? EnumSet.of(options[0], options) : diff --git a/src/main/java/org/openrewrite/java/template/processor/RecipeWriter.java b/src/main/java/org/openrewrite/java/template/processor/RecipeWriter.java index f1111d6..ffd220c 100644 --- a/src/main/java/org/openrewrite/java/template/processor/RecipeWriter.java +++ b/src/main/java/org/openrewrite/java/template/processor/RecipeWriter.java @@ -18,6 +18,7 @@ import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.TypeTag; +import com.sun.tools.javac.code.Types; import com.sun.tools.javac.parser.Tokens; import com.sun.tools.javac.processing.JavacProcessingEnvironment; import com.sun.tools.javac.tree.JCTree; @@ -351,6 +352,30 @@ private String generateVisitMethod(Map beforeTemplat } visitMethod.append(" JavaTemplate.Matcher matcher;\n"); + + // Check if any before template needs the type assignability guard + String hoistedGuardType = null; + if (descriptor.afterTemplate != null) { + Types types = Types.instance(processingEnv.getContext()); + Type afterReturnType = descriptor.afterTemplate.method.getReturnType().type; + if (!(afterReturnType instanceof Type.JCVoidType)) { + for (TemplateDescriptor bt : beforeTemplates.values()) { + Type beforeReturnType = bt.method.getReturnType().type; + if (!(beforeReturnType instanceof Type.JCVoidType) && + !types.isSubtype(types.erasure(afterReturnType), types.erasure(beforeReturnType))) { + hoistedGuardType = types.erasure(afterReturnType).tsym.getQualifiedName().toString(); + break; + } + } + } + } + if (hoistedGuardType != null) { + visitMethod.append(" if (!isAssignableToTargetType(\"") + .append(hoistedGuardType).append("\")) {\n"); + visitMethod.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + visitMethod.append(" }\n"); + } + for (Map.Entry entry : beforeTemplates.entrySet()) { int arity = entry.getValue().getArity(); for (int i = 0; i < arity; i++) { diff --git a/src/test/resources/refaster/PreconditionsVerifierRecipes.java b/src/test/resources/refaster/PreconditionsVerifierRecipes.java index defccdb..7a8596b 100644 --- a/src/test/resources/refaster/PreconditionsVerifierRecipes.java +++ b/src/test/resources/refaster/PreconditionsVerifierRecipes.java @@ -180,6 +180,9 @@ public TreeVisitor getVisitor() { @Override public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { JavaTemplate.Matcher matcher; + if (!isAssignableToTargetType("java.lang.Object")) { + return super.visitMethodInvocation(elem, ctx); + } if (before == null) { before = JavaTemplate.builder("com.google.common.base.Strings.nullToEmpty(#{value:any(java.lang.String)})") .bindType("java.lang.String") @@ -272,6 +275,9 @@ public TreeVisitor getVisitor() { @Override public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { JavaTemplate.Matcher matcher; + if (!isAssignableToTargetType("java.lang.Object")) { + return super.visitMethodInvocation(elem, ctx); + } if (before == null) { before = JavaTemplate.builder("com.google.common.base.Strings.nullToEmpty(#{value:any(java.lang.String)})") .bindType("java.lang.String")