diff --git a/src/main/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionals.java b/src/main/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionals.java index 831b4f9d..8579963f 100644 --- a/src/main/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionals.java +++ b/src/main/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionals.java @@ -19,7 +19,6 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; -import org.openrewrite.internal.ListUtils; import org.openrewrite.internal.StringUtils; import org.openrewrite.java.*; import org.openrewrite.java.search.UsesMethod; @@ -72,24 +71,23 @@ private static class AddIfEnabledVisitor extends JavaVisitor { @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); - if ((infoMatcher.matches(m) || debugMatcher.matches(m) || traceMatcher.matches(m)) && - !isInIfStatementWithLogLevelCheck(getCursor(), m)) { - List arguments = ListUtils.filter(m.getArguments(), a -> a instanceof J.MethodInvocation); - if (m.getSelect() != null && !arguments.isEmpty()) { - J container = getCursor().getParentTreeCursor().getValue(); - if (container instanceof J.Block) { - UUID id = container.getId(); - J.If if_ = ((J.If) JavaTemplate - .builder("if(#{logger:any(org.slf4j.Logger)}.is#{}Enabled()) {}") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "slf4j-api-2.+")) - .build() - .apply(getCursor(), m.getCoordinates().replace(), - m.getSelect(), StringUtils.capitalize(m.getSimpleName()))) - .withThenPart(m.withPrefix(m.getPrefix().withWhitespace("\n" + m.getPrefix().getWhitespace().replace("\n", "")))) - .withPrefix(m.getPrefix().withComments(emptyList())); - visitedBlocks.add(id); - return if_; - } + if (m.getSelect() != null && + (infoMatcher.matches(m) || debugMatcher.matches(m) || traceMatcher.matches(m)) && + !isInIfStatementWithLogLevelCheck(getCursor(), m) && + isAnyArgumentExpensive(m)) { + J container = getCursor().getParentTreeCursor().getValue(); + if (container instanceof J.Block) { + UUID id = container.getId(); + J.If if_ = ((J.If) JavaTemplate + .builder("if(#{logger:any(org.slf4j.Logger)}.is#{}Enabled()) {}") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "slf4j-api-2.+")) + .build() + .apply(getCursor(), m.getCoordinates().replace(), + m.getSelect(), StringUtils.capitalize(m.getSimpleName()))) + .withThenPart(m.withPrefix(m.getPrefix().withWhitespace("\n" + m.getPrefix().getWhitespace().replace("\n", "")))) + .withPrefix(m.getPrefix().withComments(emptyList())); + visitedBlocks.add(id); + return if_; } } return m; @@ -114,6 +112,55 @@ private boolean isInIfStatementWithLogLevelCheck(Cursor cursor, J.MethodInvocati (debugMatcher.matches(m) && sideEffects.stream().allMatch(e -> e instanceof J.MethodInvocation && isDebugEnabledMatcher.matches((J.MethodInvocation) e))) || (traceMatcher.matches(m) && sideEffects.stream().allMatch(e -> e instanceof J.MethodInvocation && isTraceEnabledMatcher.matches((J.MethodInvocation) e))); } + + private boolean isAnyArgumentExpensive(J.MethodInvocation m) { + return m + .getArguments() + .stream() + .anyMatch(arg -> + !(arg instanceof J.MethodInvocation && isSimpleGetter((J.MethodInvocation) arg) || + arg instanceof J.Literal || + arg instanceof J.Identifier || + arg instanceof J.FieldAccess || + arg instanceof J.Binary && isOnlyLiterals((J.Binary) arg)) + ); + } + + private static boolean isSimpleGetter(J.MethodInvocation mi) { + return ((mi.getSimpleName().startsWith("get") && mi.getSimpleName().length() > 3) || + (mi.getSimpleName().startsWith("is") && mi.getSimpleName().length() > 2)) && + mi.getMethodType() != null && + mi.getMethodType().getParameterNames().isEmpty() && + ((mi.getSelect() == null || mi.getSelect() instanceof J.Identifier) && + !mi.getMethodType().hasFlags(Flag.Static)); + } + + private static boolean isOnlyLiterals(J.Binary binary) { + return isLiteralOrBinary(binary.getLeft()) && isLiteralOrBinary(binary.getRight()); + } + + private static boolean isLiteralOrBinary(J expression) { + return expression instanceof J.Literal || + isSimpleBooleanGetter(expression) || + isBooleanIdentifier(expression) || + expression instanceof J.Binary && isOnlyLiterals((J.Binary) expression); + } + + private static boolean isSimpleBooleanGetter(J expression) { + if (expression instanceof J.MethodInvocation) { + J.MethodInvocation mi = (J.MethodInvocation) expression; + return isSimpleGetter(mi) && mi.getMethodType() != null && isTypeBoolean(mi.getMethodType().getReturnType()); + } + return false; + } + + private static boolean isBooleanIdentifier(J expression) { + return expression instanceof J.Identifier && isTypeBoolean(((J.Identifier) expression).getType()); + } + + private static boolean isTypeBoolean(@Nullable JavaType type) { + return type == JavaType.Primitive.Boolean || TypeUtils.isAssignableTo("java.lang.Boolean", type); + } } @EqualsAndHashCode(callSuper = false) diff --git a/src/test/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionalsTest.java b/src/test/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionalsTest.java index a57f5be8..12efb3f5 100644 --- a/src/test/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionalsTest.java +++ b/src/test/java/org/openrewrite/java/logging/slf4j/WrapExpensiveLogStatementsInConditionalsTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; import org.openrewrite.java.JavaParser; @@ -580,4 +581,113 @@ String expensiveOp() { ) ); } + + @ValueSource(strings = { + "notAGetter()", // not a getter + "new A()", // allocating a new object + "new A().getClass()", // allocating a new object first + "input.getBytes(StandardCharsets.UTF_16)", // getter with an argument + "getClass().getName()", // getter on a method invocation expression + "optional.get()", // not a getter + "A.getExpensive()", // static getter likely to use external resources or allocate things + "getExpensive()", // static getter likely to use external resources or allocate things + "342 + input", // allocating a new string + "\"foo\" + getClass()", // allocating a new string + "true && isSomething(1)" // boolean getter with an argument + }) + @ParameterizedTest + void wrapWhenExpensiveArgument(String logArgument) { + //language=java + rewriteRun( + java( + String.format(""" + import java.nio.charset.StandardCharsets; + import java.util.Optional; + import org.slf4j.Logger; + + class A { + void method(Logger log, String input, Optional optional, boolean boolVariable) { + log.info("{}", %s); + } + + String notAGetter() { + return "property"; + } + + static String getExpensive() { + return "expensive"; + } + + boolean isSomething(int i) { + return true; + } + } + """, logArgument), + String.format(""" + import java.nio.charset.StandardCharsets; + import java.util.Optional; + import org.slf4j.Logger; + + class A { + void method(Logger log, String input, Optional optional, boolean boolVariable) { + if (log.isInfoEnabled()) { + log.info("{}", %s); + } + } + + String notAGetter() { + return "property"; + } + + static String getExpensive() { + return "expensive"; + } + + boolean isSomething(int i) { + return true; + } + } + """, logArgument) + ) + ); + } + + @ValueSource(strings = { + "input", // identifier alone + "getClass()", // a getter + "log.getName()", // a getter + "34 + 78", // literal + "8344", // literal + "\"like, literally!\"", // literal + "\"one\" + \"two\" + \"three\"", // compile time literal + "\"one\" + 1", // compile time literal + "true && false", // boolean literal + "true && isSomething()", // boolean literal and boolean getter + "true && boolVariable || isSomething()", // boolean literal and boolean variable + "field", // field identifier + "this.field", // field access + }) + @ParameterizedTest + void dontWrapWhenCheapArgument(String logArgument) { + //language=java + rewriteRun( + java( + String.format(""" + import org.slf4j.Logger; + + class A { + String field; + + void method(Logger log, String input, boolean boolVariable) { + log.info("{}", %s); + } + + boolean isSomething() { + return true; + } + } + """, logArgument) + ) + ); + } }