Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ private static class ValidExpressionExtractor
implements RowExpressionVisitor<Boolean, Boolean>
{
// Bind expression will complicate the lambda expression, we apply this optimization before DesugarLambdaRule. And if there are bind expression, skip
// SWITCH, COALESCE, IF are conditional expressions, some of their expressions may not be executed, but will always be executed if pulled out
private static final List<SpecialFormExpression.Form> UNSUPPORTED_TYPES = ImmutableList.of(SpecialFormExpression.Form.SWITCH, SpecialFormExpression.Form.BIND,
SpecialFormExpression.Form.COALESCE, SpecialFormExpression.Form.WHEN, SpecialFormExpression.Form.IF);
private static final List<SpecialFormExpression.Form> UNSUPPORTED_TYPES = ImmutableList.of(SpecialFormExpression.Form.BIND);
private final RowExpressionDeterminismEvaluator determinismEvaluator;
private final FunctionResolution functionResolution;
private final List<VariableReferenceExpression> inputVariables;
Expand Down Expand Up @@ -250,20 +248,48 @@ public Boolean visitCall(CallExpression call, Boolean context)
return false;
}

// For the conditional expressions, not all arguments will be evaluated, we only try to extract from the arguments which will always be executed
private static List<RowExpression> getValidArguments(SpecialFormExpression specialForm)
{
List<RowExpression> validArgument;
SpecialFormExpression.Form form = specialForm.getForm();
if (form.equals(SpecialFormExpression.Form.IF) || form.equals(SpecialFormExpression.Form.COALESCE) || form.equals(SpecialFormExpression.Form.WHEN)) {
validArgument = ImmutableList.of(specialForm.getArguments().get(0));
}
else if (form.equals(SpecialFormExpression.Form.SWITCH)) {
validArgument = ImmutableList.of(specialForm.getArguments().get(0), specialForm.getArguments().get(1));
}
else {
validArgument = specialForm.getArguments();
}
return validArgument;
}

// When expression cannot be pulled out, hence if we get a when expression, try to pull out its argument instead
private static RowExpression getArgumentOfWhen(RowExpression expression)
{
if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(SpecialFormExpression.Form.WHEN)) {
return getArgumentOfWhen(((SpecialFormExpression) expression).getArguments().get(0));
}
return expression;
}

@Override
public Boolean visitSpecialForm(SpecialFormExpression specialForm, Boolean context)
{
if (UNSUPPORTED_TYPES.contains(specialForm.getForm())) {
return false;
}
Map<RowExpression, Boolean> validRowExpressionMap = specialForm.getArguments().stream().distinct().collect(toImmutableMap(identity(), x -> x.accept(this, context)));
List<RowExpression> validArguments = getValidArguments(specialForm);
Map<RowExpression, Boolean> validRowExpressionMap = specialForm.getArguments().stream().distinct().collect(toImmutableMap(identity(), x -> validArguments.contains(x) ? x.accept(this, context) : false));
if (context.equals(Boolean.TRUE)) {
boolean allArgumentsValid = validRowExpressionMap.values().stream().allMatch(x -> x.equals(Boolean.TRUE));
if (!allArgumentsValid) {
candidates.addAll(validRowExpressionMap.entrySet().stream()
.filter(x -> x.getValue().equals(Boolean.TRUE))
.filter(x -> isSupportedExpression(x.getKey()))
.map(Map.Entry::getKey)
.map(ValidExpressionExtractor::getArgumentOfWhen)
.filter(ValidExpressionExtractor::isSupportedExpression)
.collect(toImmutableList()));
}
return allArgumentsValid && determinismEvaluator.isDeterministic(specialForm);
Expand Down Expand Up @@ -300,7 +326,7 @@ public Boolean visitInputReference(InputReferenceExpression reference, Boolean c
}

// WHEN expression should only exist within SWITCH expression, and will throw exception in RowExpressionInterpreter, also no byte code generator for standalone WHEN expression
private boolean isSupportedExpression(RowExpression expression)
private static boolean isSupportedExpression(RowExpression expression)
{
return expression instanceof CallExpression || (expression instanceof SpecialFormExpression && !((SpecialFormExpression) expression).getForm().equals(SpecialFormExpression.Form.WHEN));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.Session;
import com.facebook.presto.common.block.IntArrayBlock;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionMetadata;
Expand Down Expand Up @@ -52,6 +53,7 @@
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubscriptExpression;
Expand Down Expand Up @@ -237,6 +239,34 @@ protected Boolean visitIsNotNullPredicate(IsNotNullPredicate expected, RowExpres
return process(expected.getValue(), ((SpecialFormExpression) argument).getArguments().get(0));
}

@Override
protected Boolean visitSearchedCaseExpression(SearchedCaseExpression node, RowExpression actual)
{
if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(SWITCH)) {
return false;
}
SpecialFormExpression specialForm = (SpecialFormExpression) actual;
int argumentSize = node.getWhenClauses().size() + 1;
if (node.getDefaultValue().isPresent()) {
++argumentSize;
}
if (specialForm.getArguments().size() != argumentSize) {
return false;
}
if (!specialForm.getArguments().get(0).equals(constant(true, BooleanType.BOOLEAN))) {
return false;
}
for (int i = 0; i < node.getWhenClauses().size(); ++i) {
if (!process(node.getWhenClauses().get(i), specialForm.getArguments().get(i + 1))) {
return false;
}
}
if (node.getDefaultValue().isPresent()) {
return process(node.getDefaultValue().get(), specialForm.getArguments().get(argumentSize - 1));
}
return true;
}

@Override
protected Boolean visitInPredicate(InPredicate expected, RowExpression actual)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,51 @@ public void testSwitchWhenExpression()
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(arr, x -> concat(case when arr2 is null then '*' when contains(arr2, x) then '+' else ' ' end, x))")).build(),
p.values(p.variable("arr", new ArrayType(VARCHAR)), p.variable("arr2", new ArrayType(VARCHAR))));
}).matches(
project(
ImmutableMap.of("expr", expression("transform(arr, x -> concat(case when expr_0 then '*' when contains(arr2, x) then '+' else ' ' end, x))")),
project(ImmutableMap.of("expr_0", expression("arr2 is null")),
values("arr", "arr2"))));
}

// Candidate expression for extract is the second when expression, hence skip
@Test
public void testInvalidSwitchWhenExpression()
{
tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
.setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
.on(p ->
{
p.variable("arr", new ArrayType(VARCHAR));
p.variable("arr2", new ArrayType(VARCHAR));
return p.project(
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(arr, x -> concat(case when contains(arr2, x) then '*' when arr2 is null then '+' else ' ' end, x))")).build(),
p.values(p.variable("arr", new ArrayType(VARCHAR)), p.variable("arr2", new ArrayType(VARCHAR))));
}).doesNotFire();
}

@Test
public void testCaseWhenExpression()
{
tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
.setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
.on(p ->
{
p.variable("arr", new ArrayType(VARCHAR));
p.variable("arr2", new ArrayType(VARCHAR));
p.variable("col1");
return p.project(
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(arr, x -> concat(case (col1 > 2) when arr2 is null then '*' when contains(arr2, x) then '+' else ' ' end, x))")).build(),
p.values(p.variable("arr", new ArrayType(VARCHAR)), p.variable("arr2", new ArrayType(VARCHAR)), p.variable("col1")));
}).matches(
project(
ImmutableMap.of("expr", expression("transform(arr, x -> concat(case expr_1 when expr_0 then '*' when contains(arr2, x) then '+' else ' ' end, x))")),
project(ImmutableMap.of("expr_0", expression("arr2 is null"), "expr_1", expression("col1>2")),
values("arr", "arr2", "col1"))));
}

@Test
public void testConditionalExpression()
{
Expand All @@ -224,4 +266,59 @@ public void testConditionalExpression()
p.values(p.variable("col1", new ArrayType(BOOLEAN)), p.variable("col2", new ArrayType(BIGINT))));
}).doesNotFire();
}

@Test
public void testIfExpressionOnCondition()
{
tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
.setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
.on(p ->
{
p.variable("col1", new ArrayType(BOOLEAN));
p.variable("col2", new ArrayType(BIGINT));
p.variable("col3");
return p.project(
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(col1, x -> if(col3 > 2, col2[2], 0))")).build(),
p.values(p.variable("col1", new ArrayType(BOOLEAN)), p.variable("col2", new ArrayType(BIGINT)), p.variable("col3")));
}).matches(
project(
ImmutableMap.of("expr", expression("transform(col1, x -> if(greater_than, col2[2], 0))")),
project(ImmutableMap.of("greater_than", expression("col3>2")),
values("col1", "col2", "col3"))));
}

@Test
public void testIfExpressionOnValue()
{
tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
.setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
.on(p ->
{
p.variable("col1", new ArrayType(BOOLEAN));
p.variable("col2", new ArrayType(BIGINT));
p.variable("col3");
return p.project(
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(col1, x -> if(x, col3 - 2, 0))")).build(),
p.values(p.variable("col1", new ArrayType(BOOLEAN)), p.variable("col2", new ArrayType(BIGINT)), p.variable("col3")));
}).doesNotFire();
}

@Test
public void testSubscriptExpression()
{
tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
.setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
.on(p ->
{
p.variable("col1", new ArrayType(BOOLEAN));
p.variable("col2", new ArrayType(BIGINT));
p.variable("col3");
return p.project(
Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
"transform(col1, x -> col2[2])")).build(),
p.values(p.variable("col1", new ArrayType(BOOLEAN)), p.variable("col2", new ArrayType(BIGINT)), p.variable("col3")));
}).doesNotFire();
}
}