Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,44 +59,54 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
Preconditions.not(new GroovyFileChecker<>())
);
return Preconditions.check(preconditions, new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.Block visitBlock(J.Block originalBlock, ExecutionContext ctx) {
J.Block block = super.visitBlock(originalBlock, ctx);

AtomicReference<J.@Nullable Switch> originalSwitch = new AtomicReference<>();

int lastIndex = block.getStatements().size() - 1;
return block.withStatements(ListUtils.map(block.getStatements(), (index, statement) -> {
if (statement == originalSwitch.getAndSet(null)) {
doAfterVisit(new InlineVariable().getVisitor());
doAfterVisit(new SwitchExpressionYieldToArrow().getVisitor());
// We've already converted the switch/assignments to an assignment with a switch expression.
return null;
}
boolean supportsMultiCaseLabelsWithDefaultCase = false;

@Override
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
supportsMultiCaseLabelsWithDefaultCase = SwitchUtils.supportsMultiCaseLabelsWithDefaultCase(cu);
return super.visitCompilationUnit(cu, ctx);
}

@Override
public J.Block visitBlock(J.Block originalBlock, ExecutionContext ctx) {
J.Block block = super.visitBlock(originalBlock, ctx);

AtomicReference<J.@Nullable Switch> originalSwitch = new AtomicReference<>();

int lastIndex = block.getStatements().size() - 1;
return block.withStatements(ListUtils.map(block.getStatements(), (index, statement) -> {
if (statement == originalSwitch.getAndSet(null)) {
doAfterVisit(new InlineVariable().getVisitor());
doAfterVisit(new SwitchExpressionYieldToArrow().getVisitor());
// We've already converted the switch/assignments to an assignment with a switch expression.
return null;
}

if (index < lastIndex &&
statement instanceof J.VariableDeclarations &&
((J.VariableDeclarations) statement).getVariables().size() == 1 &&
!canHaveSideEffects(((J.VariableDeclarations) statement).getVariables().get(0).getInitializer()) &&
block.getStatements().get(index + 1) instanceof J.Switch) {
J.VariableDeclarations vd = (J.VariableDeclarations) statement;
J.Switch nextStatementSwitch = (J.Switch) block.getStatements().get(index + 1);

J.VariableDeclarations.NamedVariable originalVariable = vd.getVariables().get(0);
J.SwitchExpression newSwitchExpression = buildNewSwitchExpression(nextStatementSwitch, originalVariable);
if (newSwitchExpression != null) {
originalSwitch.set(nextStatementSwitch);
return vd
.withVariables(singletonList(originalVariable.getPadding().withInitializer(
JLeftPadded.<Expression>build(newSwitchExpression).withBefore(Space.SINGLE_SPACE))))
.withComments(ListUtils.concatAll(vd.getComments(), nextStatementSwitch.getComments()));
}
if (index < lastIndex &&
statement instanceof J.VariableDeclarations &&
((J.VariableDeclarations) statement).getVariables().size() == 1 &&
!canHaveSideEffects(((J.VariableDeclarations) statement).getVariables().get(0).getInitializer()) &&
block.getStatements().get(index + 1) instanceof J.Switch) {
J.VariableDeclarations vd = (J.VariableDeclarations) statement;
J.Switch nextStatementSwitch = (J.Switch) block.getStatements().get(index + 1);

if (supportsMultiCaseLabelsWithDefaultCase || !SwitchUtils.hasMultiCaseLabelsWithDefault(nextStatementSwitch.getCases().getStatements())) {
J.VariableDeclarations.NamedVariable originalVariable = vd.getVariables().get(0);
J.SwitchExpression newSwitchExpression = buildNewSwitchExpression(nextStatementSwitch, originalVariable);
if (newSwitchExpression != null) {
originalSwitch.set(nextStatementSwitch);
return vd
.withVariables(singletonList(originalVariable.getPadding().withInitializer(
JLeftPadded.<Expression>build(newSwitchExpression).withBefore(Space.SINGLE_SPACE))))
.withComments(ListUtils.concatAll(vd.getComments(), nextStatementSwitch.getComments()));
}
return statement;
}));
}
}
return statement;
}));
}

private J.@Nullable SwitchExpression buildNewSwitchExpression(J.Switch originalSwitch, J.VariableDeclarations.NamedVariable originalVariable) {
private J.@Nullable SwitchExpression buildNewSwitchExpression(J.Switch originalSwitch, J.VariableDeclarations.NamedVariable originalVariable) {
J.Identifier originalVariableId = originalVariable.getName();
AtomicBoolean isQualified = new AtomicBoolean(true);
AtomicBoolean isDefaultCaseAbsent = new AtomicBoolean(true);
Expand Down Expand Up @@ -185,7 +195,7 @@ public J.Block visitBlock(J.Block originalBlock, ExecutionContext ctx) {
return null;
}

private J.@Nullable Assignment extractAssignmentOfVariable(J maybeAssignment, J.Identifier variableId) {
private J.@Nullable Assignment extractAssignmentOfVariable(J maybeAssignment, org.openrewrite.java.tree.J.Identifier variableId) {
if (maybeAssignment instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) maybeAssignment;
if (assignment.getVariable() instanceof J.Identifier) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
Preconditions.not(new GroovyFileChecker<>())
);
return Preconditions.check(preconditions, new JavaIsoVisitor<ExecutionContext>() {
boolean supportsMultiCaseLabelsWithDefaultCase = false;

@Override
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
supportsMultiCaseLabelsWithDefaultCase = SwitchUtils.supportsMultiCaseLabelsWithDefaultCase(cu);
return super.visitCompilationUnit(cu, ctx);
}

@Override
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
J.Block b = super.visitBlock(block, ctx);
Expand All @@ -76,7 +84,12 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
}

private boolean canConvertToSwitchExpression(J.Switch switchStatement) {
for (Statement statement : switchStatement.getCases().getStatements()) {
List<Statement> statements = switchStatement.getCases().getStatements();
if (!supportsMultiCaseLabelsWithDefaultCase && SwitchUtils.hasMultiCaseLabelsWithDefault(statements)) {
return false;
}

for (Statement statement : statements) {
if (!(statement instanceof J.Case)) {
return false;
}
Expand Down
30 changes: 30 additions & 0 deletions src/main/java/org/openrewrite/java/migrate/lang/SwitchUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,45 @@
*/
package org.openrewrite.java.migrate.lang;

import org.openrewrite.java.marker.JavaVersion;
import org.openrewrite.java.tree.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;

import static java.util.stream.Collectors.toSet;

class SwitchUtils {
/**
* Checks if it's valid to use a switch expression construct like:
*
* <pre>
* <code>return switch (str) {
* case "foo" -> "Foo";
* case "ignore", default -> "Other";
* };</code>
* </pre>
* @param cu The compilation unit
* @return true if the used Java version supports this construct, false otherwise
*/
public static boolean supportsMultiCaseLabelsWithDefaultCase(J.CompilationUnit cu) {
Optional<JavaVersion> version = cu.getMarkers().findFirst(JavaVersion.class);
return version.isPresent() && version.get().getMajorReleaseVersion() >= 21;
}

public static boolean hasMultiCaseLabelsWithDefault(List<Statement> cases) {
if (!cases.isEmpty() && cases.get(cases.size() - 1) instanceof J.Case) {
J.Case lastCase = (J.Case) cases.get(cases.size() - 1);
if (lastCase.getCaseLabels().size() > 1) {
J lastCaseLabel = lastCase.getCaseLabels().get(lastCase.getCaseLabels().size() - 1);
return lastCaseLabel instanceof J.Identifier && "default".equals(((J.Identifier) lastCaseLabel).getSimpleName());
}
}
return false;
}

/**
* Checks if a switch statement covers all possible values of its selector.
* This is typically used to determine if a switch statement is "exhaustive" as per the Java language specification.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,4 +622,44 @@ void doFormat(String str) {
)
);
}

@Test
void notConvertWhenDefaultAsSecondLabelColonCase() {
rewriteRun(
//language=java
java(
"""
class A {
void doFormat(String str) {
String formatted = "initialValue";
switch (str) {
case "foo": formatted = "Foo"; break;
case "ignored", default: formatted = "unknown";
}
}
}
"""
)
);
}

@Test
void notConvertWhenDefaultAsSecondLabelArrowCase() {
rewriteRun(
//language=java
java(
"""
class B {
void doFormat(String str) {
String formatted = "initialValue";
switch (str) {
case "foo" -> formatted = "Foo";
case "ignored", default -> formatted = "Other";
}
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ String doFormat(String str) {
switch (str) {
case "foo": return "Foo";
case "bar": return "Bar";
case null, default: return "Other";
default: return "Other";
}
}
}
Expand All @@ -88,7 +88,7 @@ String doFormat(String str) {
return switch (str) {
case "foo" -> "Foo";
case "bar" -> "Bar";
case null, default -> "Other";
default -> "Other";
};
}
}
Expand Down Expand Up @@ -302,7 +302,26 @@ String process(String str) {
}

@Test
void supportMultiLabelWithNullSwitch() {
void doNotConvertMultiLabelWithNWithDefaultCaseWhenItsUnsupported() {
rewriteRun(
//language=java
java(
"""
class A {
String doFormat(String str) {
switch (str) {
case "foo": return "Foo";
case "ignored", default: return "Other";
}
}
}
"""
)
);
}

@Test
void supportMultiLabelWithNullSwitchIfPossible() {
rewriteRun(
version(
//language=java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void doFormat(TrafficLight light) {
}

@Test
void supportMultiLabelWithNullSwitch() {
void supportMultiLabelWithNullSwitchIfPossible() {
rewriteRun(
version(
//language=java
Expand Down